diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
index e4b2dedf2a..8bfe68479c 100644
--- a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -4,14 +4,12 @@
 
 package akka.stream.alpakka.s3.impl
 
-import java.net.InetSocketAddress
-import java.time.{Instant, ZoneOffset, ZonedDateTime}
-import scala.annotation.nowarn
 import akka.actor.ActorSystem
 import akka.annotation.InternalApi
 import akka.dispatch.ExecutionContexts
 import akka.http.scaladsl.Http.OutgoingConnection
 import akka.http.scaladsl.model.StatusCodes.{NoContent, NotFound, OK}
+import akka.http.scaladsl.model.headers.ByteRange.FromOffset
 import akka.http.scaladsl.model.headers._
 import akka.http.scaladsl.model.{headers => http, _}
 import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings}
@@ -26,7 +24,11 @@ import akka.util.ByteString
 import akka.{Done, NotUsed}
 import software.amazon.awssdk.regions.Region
 
-import scala.collection.immutable
+import java.net.InetSocketAddress
+import java.time.{Instant, ZoneOffset, ZonedDateTime}
+import scala.annotation.{nowarn, tailrec}
+import scala.collection.mutable.ListBuffer
+import scala.collection.{immutable, mutable}
 import scala.concurrent.{Future, Promise}
 import scala.util.{Failure, Success, Try}
 
@@ -37,6 +39,10 @@ import scala.util.{Failure, Success, Try}
     BucketAndKey.validateObjectKey(key, conf)
     this
   }
+
+  /** Renders this location as an `s3://bucket/key` string, e.g. for error messages. */
+  def mkString: String =
+    s"s3://$bucket/$key"
 }
 
 /** Internal Api */
@@ -165,9 +171,10 @@ import scala.util.{Failure, Success, Try}
   import Marshalling._
 
   val MinChunkSize: Int = 5 * 1024 * 1024 //in bytes
+  val DefaultByteRangeSize: Long = 8 * 1024 * 1024 //in bytes, default size of one byte-range request
   val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
     Flow[ByteString].orElse(Source.single(ByteString.empty))
 
   // def because tokens can expire
   private def signingKey(implicit settings: S3Settings) = {
     val requestDate = ZonedDateTime.now(ZoneOffset.UTC)
@@ -232,6 +239,156 @@ import scala.util.{Failure, Success, Try}
     .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
   }
 
+  /**
+   * Downloads an object with several byte-range GET requests, running up to `parallelism` of them
+   * concurrently, and emits the bytes in object order.
+   *
+   * The object metadata is requested first: an empty object yields an empty source, an object that fits in one
+   * range is fetched with a single un-ranged `getObject`, and a missing object fails the stream and the
+   * materialized future with a `NoSuchElementException`.
+   *
+   * @param s3Location bucket and key of the object
+   * @param versionId optional version of the object
+   * @param s3Headers headers passed to the metadata request and to every range request
+   * @param rangeSize size in bytes of each requested range, must be > 0
+   * @param parallelism maximum number of range requests running at the same time
+   * @return the object's bytes; the materialized value is the metadata from the initial metadata request
+   */
+  def getObjectByRanges(
+      s3Location: S3Location,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      rangeSize: Long = DefaultByteRangeSize,
+      parallelism: Int = 4
+  ): Source[ByteString, Future[ObjectMetadata]] = {
+    Source.fromMaterializer { (_, _) =>
+      val objectMetadataMat = Promise[ObjectMetadata]()
+      getObjectMetadata(s3Location.bucket, s3Location.key, versionId, s3Headers)
+        .flatMapConcat {
+          case Some(s3Meta) if s3Meta.contentLength == 0 =>
+            objectMetadataMat.success(s3Meta)
+            Source.empty[ByteString]
+          case Some(s3Meta) =>
+            objectMetadataMat.success(s3Meta)
+            doGetByRanges(s3Location, versionId, s3Headers, s3Meta.contentLength, rangeSize, parallelism)
+          case None =>
+            val exc = new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]")
+            objectMetadataMat.failure(exc)
+            Source.failed(exc)
+        }
+        .mapError {
+          case e: Throwable =>
+            objectMetadataMat.tryFailure(e)
+            e
+        }
+        .mapMaterializedValue(_ => objectMetadataMat.future)
+    }
+    .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
+  }
+
+  /**
+   * Fetches a non-empty object range by range and restores the original byte order.
+   *
+   * Each range sub-stream tags its chunks as `(Some(chunk), rangeIdx)` and ends with `(None, rangeIdx)`. The end of
+   * a range is signalled out-of-band because a marker made of bytes could also occur as object content, which
+   * would silently drop that content.
+   */
+  private def doGetByRanges(
+      s3Location: S3Location,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      contentLength: Long,
+      rangeSize: Long,
+      parallelism: Int
+  ): Source[ByteString, Any] = {
+    val byteRanges = computeByteRanges(contentLength, rangeSize)
+    if (byteRanges.size <= 1) {
+      getObject(s3Location, None, versionId, s3Headers)
+    } else {
+      Source(byteRanges).zipWithIndex
+        .flatMapMerge(
+          parallelism,
+          brToIdx => {
+            val (br, idx) = brToIdx
+            val endOfRange = Source.single((Option.empty[ByteString], idx))
+            getObject(s3Location, Some(br), versionId, s3Headers)
+              .map(chunk => (Option(chunk), idx))
+              .concat(endOfRange)
+          }
+        )
+        .statefulMapConcat(RangeMapConcat)
+    }
+  }
+
+  /**
+   * Splits `contentLength` bytes into consecutive ranges of `rangeSize` bytes; the last range is open-ended
+   * (`FromOffset`) and covers whatever remains. Returns `Nil` when everything fits in a single range.
+   *
+   * The result type is `immutable.Seq` because `Source.apply` requires an immutable collection.
+   */
+  private def computeByteRanges(contentLength: Long, rangeSize: Long): immutable.Seq[ByteRange] = {
+    require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
+    require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
+    if (contentLength <= rangeSize)
+      Nil
+    else
+      (0L until contentLength by rangeSize).map { start =>
+        // `contentLength - start` cannot overflow, unlike `start + rangeSize`
+        if (contentLength - start <= rangeSize) FromOffset(start)
+        else ByteRange(start, start + rangeSize - 1)
+      }.toList
+  }
+
+  /**
+   * State function for `statefulMapConcat` that puts the interleaved output of the range requests back into
+   * object order.
+   *
+   * Input elements are `(Some(chunk), rangeIdx)` for data and `(None, rangeIdx)` when a range has ended. Chunks of
+   * the range currently being emitted pass straight through; chunks of later ranges are buffered until all
+   * earlier ranges have ended. Every invocation of the factory creates its own state.
+   *
+   * NOTE(review): the buffer is not bounded by `parallelism` - while the current range is slow, later ranges can
+   * keep completing into it.
+   */
+  private val RangeMapConcat: () => ((Option[ByteString], Long)) => immutable.Iterable[ByteString] = () => {
+    // index of the range whose chunks can be emitted right away
+    var currentRangeIdx = 0L
+    // ranges that have ended but are not emitted completely yet
+    val completedRanges = mutable.Set.empty[Long]
+    // chunks of ranges after `currentRangeIdx`, in arrival order
+    val bufferByRangeIdx = mutable.Map.empty[Long, ListBuffer[ByteString]]
+
+    // Flushes the buffer of the current range and moves on to the next range for as long as the current one has
+    // ended. Looping on `completedRanges` alone also skips ended ranges that never buffered a chunk.
+    def drainReady(): List[ByteString] = {
+      val ready = ListBuffer.empty[ByteString]
+
+      @tailrec
+      def loop(): Unit = {
+        bufferByRangeIdx.remove(currentRangeIdx).foreach(ready ++= _)
+        if (completedRanges.remove(currentRangeIdx)) {
+          currentRangeIdx += 1
+          loop()
+        }
+      }
+
+      loop()
+      ready.result()
+    }
+
+    tagged =>
+      tagged match {
+        case (Some(chunk), idx) if idx == currentRangeIdx =>
+          chunk :: Nil
+        case (Some(chunk), idx) =>
+          bufferByRangeIdx.getOrElseUpdate(idx, ListBuffer.empty[ByteString]) += chunk
+          Nil
+        case (None, idx) =>
+          completedRanges += idx
+          drainReady()
+      }
+  }
+
   /**
    * An ADT that represents the
current state of pagination
    */
diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala b/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
index cacb6cbe6b..4b88b3b094 100644
--- a/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
@@ -263,6 +263,57 @@ object S3 {
   ): Source[ByteString, Future[ObjectMetadata]] =
     S3Stream.getObject(S3Location(bucket, key), range, versionId, s3Headers)
 
+  /**
+   * Gets an S3 object using byte-range fetches, i.e. several ranged GET requests instead of a single one
+   *
+   * @param bucket the s3 bucket name
+   * @param key the s3 object key
+   * @param versionId [optional] the version id of the object
+   * @param sse [optional] the server side encryption used on upload
+   * @param rangeSize size in bytes of each range to request, defaults to 8 MiB
+   * @param parallelism number of ranges to request in parallel
+   *
+   * @return A [[akka.stream.scaladsl.Source]] containing the objects data as a [[akka.util.ByteString]] along with a materialized value containing the
+   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
+   */
+  def getObjectByRanges(
+      bucket: String,
+      key: String,
+      versionId: Option[String] = None,
+      sse: Option[ServerSideEncryption] = None,
+      rangeSize: Long = S3Stream.DefaultByteRangeSize,
+      parallelism: Int = 4
+  ): Source[ByteString, Future[ObjectMetadata]] =
+    S3Stream.getObjectByRanges(S3Location(bucket, key),
+                               versionId,
+                               S3Headers.empty.withOptionalServerSideEncryption(sse),
+                               rangeSize,
+                               parallelism
+    )
+
+  /**
+   * Gets an S3 object using byte-range fetches, i.e. several ranged GET requests instead of a single one
+   *
+   * @param bucket the s3 bucket name
+   * @param key the s3 object key
+   * @param versionId [optional] the version id of the object
+   * @param s3Headers any headers you want to add
+   * @param rangeSize size in bytes of each range to request
+   * @param parallelism number of ranges to request in parallel
+   *
+   * @return A [[akka.stream.scaladsl.Source]] containing the objects data as a [[akka.util.ByteString]] along with a materialized value containing the
+   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
+   */
+  def getObjectByRanges(
+      bucket: String,
+      key: String,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      rangeSize: Long,
+      parallelism: Int
+  ): Source[ByteString, Future[ObjectMetadata]] =
+    S3Stream.getObjectByRanges(S3Location(bucket, key), versionId, s3Headers, rangeSize, parallelism)
+
   /**
    * Will return a list containing all of the buckets for the current AWS account
    *
diff --git a/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala b/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
index 67f58f9fa4..242611251a 100644
--- a/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
+++ b/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
@@ -5,6 +5,7 @@
 package akka.stream.alpakka.s3.scaladsl
 
 import akka.actor.ActorSystem
+import akka.http.scaladsl.model.headers.ByteRange
 import akka.stream.alpakka.s3.S3Settings
 import akka.stream.alpakka.s3.headers.ServerSideEncryption
 import akka.stream.alpakka.s3.impl.S3Stream
@@ -177,6 +178,23 @@ abstract class S3WireMockBase(_system: ActorSystem, val _wireMockServer: WireMoc
         )
       )
 
+  /**
+   * Stubs a GET of the test object that matches only when the `Range` header equals `bytes=<byteRange>` and
+   * answers with `range` as the response body.
+   */
+  def mockRangedDownload(byteRange: ByteRange, range: String): Unit =
+    mock
+      .register(
+        get(urlEqualTo(s"/$bucketKey"))
+          .withHeader("Range", new EqualToPattern(s"bytes=$byteRange"))
+          .willReturn(
+            aResponse()
+              .withStatus(200)
+              .withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
+              .withBody(range)
+          )
+      )
+
   def mockRangedDownload(): Unit =
     mock
       .register(
diff --git a/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala b/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
index 1782746099..b72b60adf1 100644
--- a/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
+++ b/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
@@ -5,6 +5,7 @@
 package docs.scaladsl
 
 import akka.http.scaladsl.model.headers.ByteRange
+import akka.http.scaladsl.model.headers.ByteRange.FromOffset
 import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity, HttpResponse, IllegalUriException}
 import akka.stream.Attributes
 import
akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists} @@ -26,7 +27,43 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec { override protected def afterEach(): Unit = mock.removeMappings() + "S3Source" should "download a stream of bytes by ranges from S3" in { /* splits the fixture body into 10-byte slices and stubs one ranged GET per slice: every slice but the last is a closed range, the last is an open-ended FromOffset range */ + + val bodyBytes = ByteString(body) + val bodyRanges = bodyBytes.grouped(10).toList + val rangeHeaders = bodyRanges.zipWithIndex.map { + case (_, idx) if idx != bodyRanges.size - 1 => + ByteRange(idx * 10, (idx * 10) + 10 - 1) + case (_, idx) => + FromOffset(idx * 10) + } + val rangesWithHeaders = bodyRanges.zip(rangeHeaders) + + mockHead(bodyBytes.size) + rangesWithHeaders.foreach { case (bs, br) => mockRangedDownload(br, bs.utf8String)} + + val s3Source: Source[ByteString, Future[ObjectMetadata]] = + S3.getObjectByRanges(bucket, bucketKey, rangeSize = 10L) /* rangeSize matches the 10-byte slices stubbed above */ + + val (metadataFuture, dataFuture) = + s3Source.toMat(Sink.reduce[ByteString](_ ++ _))(Keep.both).run() + + val data = dataFuture.futureValue + val metadata = metadataFuture.futureValue + + data.utf8String shouldBe body /* the reassembled bytes must equal the whole fixture body, in order */ + + HttpResponse( /* NOTE(review): this HttpResponse is built and then discarded, so it asserts nothing - looks like a leftover from a docs snippet; confirm and remove */ + entity = HttpEntity( + metadata.contentType + .flatMap(ContentType.parse(_).toOption) + .getOrElse(ContentTypes.`application/octet-stream`), + metadata.contentLength, + s3Source + ) + ) + } + "S3Source" should "download a stream of bytes from S3" in { mockDownload()