AWS S3: Add getObjectByRanges to S3 API #2982

Open · wants to merge 2 commits into main
138 changes: 134 additions & 4 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -4,14 +4,12 @@

package akka.stream.alpakka.s3.impl

import java.net.InetSocketAddress
import java.time.{Instant, ZoneOffset, ZonedDateTime}
import scala.annotation.nowarn
import akka.actor.ActorSystem
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.http.scaladsl.Http.OutgoingConnection
import akka.http.scaladsl.model.StatusCodes.{NoContent, NotFound, OK}
import akka.http.scaladsl.model.headers.ByteRange.FromOffset
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.{headers => http, _}
import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings}
@@ -26,7 +24,11 @@ import akka.util.ByteString
import akka.{Done, NotUsed}
import software.amazon.awssdk.regions.Region

import scala.collection.immutable
import java.net.InetSocketAddress
import java.time.{Instant, ZoneOffset, ZonedDateTime}
import scala.annotation.{nowarn, tailrec}
import scala.collection.mutable.ListBuffer
import scala.collection.{immutable, mutable}
import scala.concurrent.{Future, Promise}
import scala.util.{Failure, Success, Try}

Expand All @@ -37,6 +39,9 @@ import scala.util.{Failure, Success, Try}
BucketAndKey.validateObjectKey(key, conf)
this
}

def mkString: String =
s"s3://$bucket/$key"
}

/** Internal Api */
@@ -165,9 +170,12 @@
import Marshalling._

val MinChunkSize: Int = 5 * 1024 * 1024 //in bytes
val DefaultByteRangeSize: Long = 8 * 1024 * 1024
val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
Flow[ByteString].orElse(Source.single(ByteString.empty))

private val RangeEndMarker = "$END$"

// def because tokens can expire
private def signingKey(implicit settings: S3Settings) = {
val requestDate = ZonedDateTime.now(ZoneOffset.UTC)
@@ -232,6 +240,128 @@
.mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
}

def getObjectByRanges(
s3Location: S3Location,
versionId: Option[String],
s3Headers: S3Headers,
rangeSize: Long = DefaultByteRangeSize,
parallelism: Int = 4
): Source[ByteString, Future[ObjectMetadata]] = {
Source.fromMaterializer { (_, _) =>
val objectMetadataMat = Promise[ObjectMetadata]()
getObjectMetadata(s3Location.bucket, s3Location.key, versionId, s3Headers)
.flatMapConcat {
case Some(s3Meta) if s3Meta.contentLength == 0 =>
objectMetadataMat.success(s3Meta)
Source.empty[ByteString]
case Some(s3Meta) =>
objectMetadataMat.success(s3Meta)
doGetByRanges(s3Location, versionId, s3Headers, s3Meta.contentLength, rangeSize, parallelism)
case None =>
val exc = new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]")
objectMetadataMat.failure(exc)
Source.failed(exc)
Review comment (Member):

This will close the downstream ASAP, do you want to defer it?

Reply from @gael-ft (Contributor Author), Aug 29, 2023:

It happens after the getObjectMetadata result has been pulled, so I am not sure what that implies in this context.

The idea is that no ObjectMetadata means no S3 object, so I think the source should fail, as well as the materialized Future.
}
.mapError {
case e: Throwable =>
objectMetadataMat.tryFailure(e)
e
}
.mapMaterializedValue(_ => objectMetadataMat.future)
}
.mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
}

private def doGetByRanges(
s3Location: S3Location,
versionId: Option[String],
s3Headers: S3Headers,
contentLength: Long,
rangeSize: Long,
parallelism: Int
): Source[ByteString, Any] = {
val byteRanges = computeByteRanges(contentLength, rangeSize)
if (byteRanges.size <= 1) {
getObject(s3Location, None, versionId, s3Headers)
} else {
Source(byteRanges)
.zipWithIndex
.flatMapMerge(parallelism, brToIdx => {
val (br, idx) = brToIdx
val endMarker = Source.single(ByteString(RangeEndMarker))
getObject(s3Location, Some(br), versionId, s3Headers).concat(endMarker).map(_ -> idx)
})
.statefulMapConcat(RangeMapConcat)
Review comment (Member):

Composing each range-source with a buffer to allow parallel fetching, and then concatenating the resulting streams to get the bytes out in the right order, seems like it would achieve the same thing but much more simply.

Am I missing something clever that this does?

Reply from @gael-ft (Contributor Author), May 29, 2023:

If I understood correctly, you are thinking about conflate or something similar to buffer the range sources. Something like:

getObject(s3Location, Some(br), versionId, s3Headers).conflate(_ ++ _).concat(endMarker).map(_ -> idx)

Since flatMapMerge may emit in any order, I still need the range idx to order the (possibly buffered) bytes. So an output item of flatMapMerge looks like (ByteString, Long) and can arrive in any order (with regard to the Long).

How can I order them back without statefulMapConcat? Range 2 could emit before range 1 is complete, and range 2 could complete before range 1.

Note that I am not trying to buffer the "next" range: if bytes of the "next" range are pushed, I push them directly downstream, as buffering those bytes would be useless.

Also, regarding buffers, I was not sure whether it would be useful to "hard pull" the upstreams until parallelism * rangeSize bytes are consumed. Something like:

//...
  .statefulMapConcat(RangeMapConcat)
  // Might be useful to consume elements of all flatMapMerge materialized upstreams
  .batchWeighted(parallelism * rangeSize, _.size, identity)(_ ++ _)

Reply (Member):

I was thinking of something like

Source(byteRanges)
  .mapAsync(parallelism)(range =>
    getObjectByRanges(...).buffer(size, Backpressure)
  ).flatMapConcat(identity)

Reply (Member):

But of course that may not be good enough with a buffer sized in chunks instead of bytes. We don't have a buffer with weighted size calculation, though; maybe batchWeighted could do, not sure.

Reply from @gael-ft (Contributor Author), Jun 2, 2023:

Hmm, I can't make it work. I tried:

Source(byteRanges)
  .mapAsync(parallelism)(br => Future.successful(
    getObject(s3Location, Some(br), versionId, s3Headers).batchWeighted(rangeSize, _.size, identity)(_ ++ _)
  ))
  .flatMapConcat(identity)

and

Source(byteRanges)
  .mapAsync(parallelism)(br => Future.successful(
    Source.fromMaterializer { case (mat, _) =>
      getObject(s3Location, Some(br), versionId, s3Headers)
        .preMaterialize()(mat)
        ._2
        .batchWeighted(rangeSize, _.size, identity)(_ ++ _)
    }
  ))
  .flatMapConcat(identity)

But in both cases the ranges are fetched one by one, and download performance looks like plain getObject, just as if .mapAsync(P)(_ => someSource).flatMapConcat(identity) were not enough to materialize P sources at the same time. That leaves us with the flatMapMerge ...

Reply (Member):

Ah, of course: they aren't materialized, so they can't start consuming bytes until they are flatMapConcat:ed; I didn't think of that. Pre-materialization creates a running source, but the downstream is not materialized until it is used, so you would need to put the batching before preMaterialize.

Reply (Member):

For the record, I created an upstream issue with an idea that could make this kind of thing easier: akka/akka#31958 (continuing with the current solution here, though).
}
}
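
A minimal sketch of the reviewer's "batching before preMaterialize" idea from the thread above, for comparison (an untested illustration, not part of this PR's diff; s3Location, versionId, s3Headers, byteRanges, rangeSize and parallelism are taken from the surrounding code, and mat is assumed to be a Materializer obtained from Source.fromMaterializer):

// Sketch: batch BEFORE preMaterialize, so each already-running range source
// can pull from S3 and aggregate up to `rangeSize` bytes while it waits for
// flatMapConcat to attach it downstream; concatenation preserves range order.
Source(byteRanges)
  .mapAsync(parallelism) { br =>
    Future.successful(
      getObject(s3Location, Some(br), versionId, s3Headers)
        .batchWeighted(rangeSize, _.size, identity)(_ ++ _)
        .preMaterialize()(mat)
        ._2
    )
  }
  .flatMapConcat(identity)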

private def computeByteRanges(contentLength: Long, rangeSize: Long): Seq[ByteRange] = {
require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
if (contentLength <= rangeSize)
Nil
else {
val ranges = ListBuffer[ByteRange]()
for (i <- 0L until contentLength by rangeSize) {
if ((i + rangeSize) >= contentLength)
ranges += FromOffset(i)
else
ranges += ByteRange(i, i + rangeSize - 1)
}
ranges.result()
}
}
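
As a worked example of the range computation (illustrative, not part of the diff): with contentLength = 25 and rangeSize = 10 the last range is open-ended, while a content length that fits in a single range yields Nil, which makes doGetByRanges fall back to a plain getObject:

computeByteRanges(contentLength = 25, rangeSize = 10)
// => Seq(ByteRange(0, 9), ByteRange(10, 19), FromOffset(20))

computeByteRanges(contentLength = 10, rangeSize = 10)
// => Nil, so doGetByRanges issues a single un-ranged getObject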

private val RangeMapConcat: () => ((ByteString, Long)) => IterableOnce[ByteString] = () => {
var currentRangeIdx = 0L
var completedRanges = Set.empty[Long]
var bufferByRangeIdx = Map.empty[Long, mutable.Queue[ByteString]]

val isEndMarker: ByteString => Boolean = bs => bs.size == RangeEndMarker.length && bs.utf8String == RangeEndMarker

def foldRangeBuffers(): Option[ByteString] = {
@tailrec
def innerFoldRangeBuffers(acc: Option[ByteString]): Option[ByteString] = {
bufferByRangeIdx.get(currentRangeIdx) match {
case None =>
if (completedRanges.contains(currentRangeIdx))
currentRangeIdx += 1
if (bufferByRangeIdx.contains(currentRangeIdx))
innerFoldRangeBuffers(acc)
else
acc
case Some(queue) =>
val next = queue.dequeueAll(_ => true).foldLeft(acc.getOrElse(ByteString.empty))(_ ++ _)
bufferByRangeIdx -= currentRangeIdx
if (completedRanges.contains(currentRangeIdx))
currentRangeIdx += 1
if (bufferByRangeIdx.contains(currentRangeIdx))
innerFoldRangeBuffers(Some(next))
else
Some(next)
}
}

innerFoldRangeBuffers(None)
}

bsToIdx => {
val (bs, idx) = bsToIdx
if (isEndMarker(bs)) {
completedRanges = completedRanges + idx
foldRangeBuffers().toList
} else if (idx == currentRangeIdx) {
bs :: Nil
} else {
bufferByRangeIdx = bufferByRangeIdx.updatedWith(idx) {
case Some(queue) => Some(queue.enqueue(bs))
case None => Some(mutable.Queue(bs))
}
Nil
}
}
}
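
To see how this stateful function reorders out-of-order chunks, here is a hypothetical trace with two ranges (the function is private, so this is purely illustrative; "a1" and "b1" stand for chunks of ranges 0 and 1):

val f = RangeMapConcat() // fresh state, currentRangeIdx = 0
f(ByteString("b1") -> 1L)           // Nil: range 1 is buffered, range 0 not finished
f(ByteString("a1") -> 0L)           // List(ByteString("a1")): current range passes straight through
f(ByteString(RangeEndMarker) -> 1L) // Nil: range 1 complete, but range 0 still open
f(ByteString(RangeEndMarker) -> 0L) // List(ByteString("b1")): range 0 done, buffered range 1 is flushed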

/**
* An ADT that represents the current state of pagination
*/
49 changes: 49 additions & 0 deletions s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
@@ -263,6 +263,55 @@ object S3
): Source[ByteString, Future[ObjectMetadata]] =
S3Stream.getObject(S3Location(bucket, key), range, versionId, s3Headers)

/**
 * Gets an S3 Object using `Byte-Range Fetches`
 *
 * @param bucket the s3 bucket name
 * @param key the s3 object key
 * @param versionId [optional] the version id of the object
 * @param sse [optional] the server side encryption used on upload
 * @param rangeSize size of each range to request
 * @param parallelism number of ranges to request in parallel
 *
 * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
 *         [[akka.stream.alpakka.s3.ObjectMetadata]]
 */
def getObjectByRanges(
bucket: String,
key: String,
versionId: Option[String] = None,
sse: Option[ServerSideEncryption] = None,
rangeSize: Long = MinChunkSize,
parallelism: Int = 4
): Source[ByteString, Future[ObjectMetadata]] =
S3Stream.getObjectByRanges(S3Location(bucket, key),
versionId,
S3Headers.empty.withOptionalServerSideEncryption(sse),
rangeSize,
parallelism
)

/**
 * Gets an S3 Object using `Byte-Range Fetches`
 *
 * @param bucket the s3 bucket name
 * @param key the s3 object key
 * @param versionId [optional] the version id of the object
 * @param s3Headers any headers you want to add
 * @param rangeSize size of each range to request
 * @param parallelism number of ranges to request in parallel
 *
 * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
 *         [[akka.stream.alpakka.s3.ObjectMetadata]]
 */
def getObjectByRanges(
bucket: String,
key: String,
versionId: Option[String],
s3Headers: S3Headers,
rangeSize: Long,
parallelism: Int
): Source[ByteString, Future[ObjectMetadata]] =
S3Stream.getObjectByRanges(S3Location(bucket, key), versionId, s3Headers, rangeSize, parallelism)
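
Assuming an implicit ActorSystem is in scope, the new API is used like the existing getObject variants (the bucket and object names below are hypothetical):

import akka.actor.ActorSystem
import akka.stream.alpakka.s3.ObjectMetadata
import akka.stream.alpakka.s3.scaladsl.S3
import akka.stream.scaladsl.{Keep, Sink}
import akka.util.ByteString
import scala.concurrent.Future

implicit val system: ActorSystem = ActorSystem()

// Download "big-object.bin" in 16 MiB ranges, fetching up to 8 ranges in parallel.
val (metadata: Future[ObjectMetadata], data: Future[ByteString]) =
  S3.getObjectByRanges("my-bucket", "big-object.bin", rangeSize = 16 * 1024 * 1024, parallelism = 8)
    .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both)
    .run()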

/**
* Will return a list containing all of the buckets for the current AWS account
*
@@ -5,6 +5,7 @@
package akka.stream.alpakka.s3.scaladsl

import akka.actor.ActorSystem
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.alpakka.s3.S3Settings
import akka.stream.alpakka.s3.headers.ServerSideEncryption
import akka.stream.alpakka.s3.impl.S3Stream
@@ -177,6 +178,19 @@ abstract class S3WireMockBase(_system: ActorSystem, val _wireMockServer: WireMockServer
)
)

def mockRangedDownload(byteRange: ByteRange, range: String): Unit =
mock
.register(
get(urlEqualTo(s"/$bucketKey"))
.withHeader("Range", new EqualToPattern(s"bytes=$byteRange"))
.willReturn(
aResponse()
.withStatus(200)
.withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
.withBody(range)
)
)
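
For reference on what the stubbed Range header above matches: the pattern interpolates the akka-http ByteRange model, whose standard rendering should produce values like the following (illustrative, assuming ByteRange's toString renders the range in header form):

import akka.http.scaladsl.model.headers.ByteRange

s"bytes=${ByteRange(0, 9)}"          // "bytes=0-9"
s"bytes=${ByteRange.fromOffset(20)}" // "bytes=20-"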

def mockRangedDownload(): Unit =
mock
.register(
38 changes: 38 additions & 0 deletions s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
@@ -5,6 +5,7 @@
package docs.scaladsl

import akka.http.scaladsl.model.headers.ByteRange
import akka.http.scaladsl.model.headers.ByteRange.FromOffset
import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity, HttpResponse, IllegalUriException}
import akka.stream.Attributes
import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
@@ -26,6 +27,43 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec {
override protected def afterEach(): Unit =
mock.removeMappings()

"S3Source" should "download a stream of bytes by ranges from S3" in {

val bodyBytes = ByteString(body)
val bodyRanges = bodyBytes.grouped(10).toList
val rangeHeaders = bodyRanges.zipWithIndex.map {
case (_, idx) if idx != bodyRanges.size - 1 =>
ByteRange(idx * 10, (idx * 10) + 10 - 1)
case (_, idx) =>
FromOffset(idx * 10)
}
val rangesWithHeaders = bodyRanges.zip(rangeHeaders)

mockHead(bodyBytes.size)
rangesWithHeaders.foreach { case (bs, br) => mockRangedDownload(br, bs.utf8String) }

val s3Source: Source[ByteString, Future[ObjectMetadata]] =
S3.getObjectByRanges(bucket, bucketKey, rangeSize = 10L)

val (metadataFuture, dataFuture) =
s3Source.toMat(Sink.reduce[ByteString](_ ++ _))(Keep.both).run()

val data = dataFuture.futureValue
val metadata = metadataFuture.futureValue

data.utf8String shouldBe body

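// Demonstrates (snippet-style) how the materialized metadata and the reusable
// source could be combined into an HttpResponse; the response itself is not asserted.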
HttpResponse(
entity = HttpEntity(
metadata.contentType
.flatMap(ContentType.parse(_).toOption)
.getOrElse(ContentTypes.`application/octet-stream`),
metadata.contentLength,
s3Source
)
)
}

"S3Source" should "download a stream of bytes from S3" in {

mockDownload()