Skip to content

Commit

Permalink
s3: Added support for partial file download from S3 #264 (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
thereisnospoon authored and jrudolph committed Apr 20, 2017
1 parent 3823145 commit e72e9ca
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 9 deletions.
9 changes: 9 additions & 0 deletions docs/src/main/paradox/s3.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ Scala
Java
: @@snip (../../../../s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #download }

In order to download a range of a file's data you can use overloaded method which
additionally takes `ByteRange` as argument.

Scala
: @@snip (../../../../s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3SourceSpec.scala) { #rangedDownload }

Java
: @@snip (../../../../s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #rangedDownload }

### Running the example code

The code in this guide is part of runnable tests of this project. You are welcome to edit the code and run it in sbt.
Expand Down
14 changes: 10 additions & 4 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.ByteRange
import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import akka.stream.Materializer
import akka.stream.alpakka.s3.acl.CannedAcl
Expand Down Expand Up @@ -59,13 +60,18 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials,
val MinChunkSize = 5242880 //in bytes
val signingKey = SigningKey(credentials, CredentialScope(LocalDate.now(), region, "s3"))

def download(s3Location: S3Location): Source[ByteString, NotUsed] = {
def download(s3Location: S3Location, range: Option[ByteRange] = None): Source[ByteString, NotUsed] = {
import mat.executionContext
Source.fromFuture(request(s3Location).flatMap(entityForSuccess).map(_.dataBytes)).flatMapConcat(identity)
Source.fromFuture(request(s3Location, range).flatMap(entityForSuccess).map(_.dataBytes)).flatMapConcat(identity)
}

def request(s3Location: S3Location): Future[HttpResponse] =
signAndGet(getDownloadRequest(s3Location, region))
def request(s3Location: S3Location, rangeOption: Option[ByteRange] = None): Future[HttpResponse] = {
val downloadRequest = getDownloadRequest(s3Location, region)
signAndGet(rangeOption match {
case Some(range) => downloadRequest.withHeaders(headers.Range(range))
case _ => downloadRequest
})
}

/**
* Uploads a stream of ByteStrings to a specified location as a multipart upload.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import java.util.concurrent.CompletionStage
import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.impl.model.JavaUri
import akka.http.javadsl.model.headers.ByteRange
import akka.http.javadsl.model.{ContentType, HttpResponse, Uri}
import akka.http.scaladsl.model.{ContentTypes, ContentType => ScalaContentType}
import akka.http.scaladsl.model.headers.{ByteRange => ScalaByteRange}
import akka.stream.Materializer
import akka.stream.alpakka.s3.acl.CannedAcl
import akka.stream.alpakka.s3.auth.AWSCredentials
Expand All @@ -35,6 +37,11 @@ final class S3Client(credentials: AWSCredentials, region: String, system: ActorS
def download(bucket: String, key: String): Source[ByteString, NotUsed] =
impl.download(S3Location(bucket, key)).asJava

def download(bucket: String, key: String, range: ByteRange): Source[ByteString, NotUsed] = {
val scalaRange = range.asInstanceOf[ScalaByteRange]
impl.download(S3Location(bucket, key), Some(scalaRange)).asJava
}

def multipartUpload(bucket: String,
key: String,
contentType: ContentType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package akka.stream.alpakka.s3.scaladsl
import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.Materializer
import akka.stream.alpakka.s3.S3Settings
import akka.stream.alpakka.s3.acl.CannedAcl
Expand Down Expand Up @@ -42,6 +43,9 @@ final class S3Client(credentials: AWSCredentials, region: String)(implicit syste

def download(bucket: String, key: String): Source[ByteString, NotUsed] = impl.download(S3Location(bucket, key))

def download(bucket: String, key: String, range: ByteRange): Source[ByteString, NotUsed] =
impl.download(S3Location(bucket, key), Some(range))

def multipartUpload(bucket: String,
key: String,
contentType: ContentType = ContentTypes.`application/octet-stream`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
package akka.stream.alpakka.s3.javadsl;

import akka.NotUsed;
import akka.actor.ActorSystem;
import akka.http.javadsl.model.Uri;
import akka.http.javadsl.model.headers.ByteRange;
import akka.stream.ActorMaterializer;
import akka.stream.Materializer;
import akka.stream.alpakka.s3.auth.AWSCredentials;
Expand All @@ -16,10 +16,12 @@
import akka.util.ByteString;
import org.junit.Test;

import java.util.Arrays;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class S3ClientTest extends S3WireMockBase {

Expand Down Expand Up @@ -63,4 +65,22 @@ public void download() throws Exception {

assertEquals(body(), result);
}

@Test
public void rangedDownload() throws Exception {

mockRangedDownload();

//#rangedDownload
final Source<ByteString, NotUsed> source = client.download(bucket(), bucketKey(),
ByteRange.createSlice(bytesRangeStart(), bytesRangeEnd()));
//#rangedDownload

final CompletionStage<byte[]> resultCompletionStage =
source.map(ByteString::toArray).runWith(Sink.head(), materializer);

byte[] result = resultCompletionStage.toCompletableFuture().get(5, TimeUnit.SECONDS);

assertTrue(Arrays.equals(rangeOfBody(), result));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package akka.stream.alpakka.s3.scaladsl

import akka.NotUsed
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.alpakka.s3.S3Exception
import akka.stream.scaladsl.{Sink, Source}
import akka.util.ByteString
Expand All @@ -17,14 +18,28 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec {
mockDownload()

//#download
val s3Source: Source[ByteString, NotUsed] = s3Client.download("testBucket", "testKey")
val s3Source: Source[ByteString, NotUsed] = s3Client.download(bucket, bucketKey)
//#download

val result: Future[String] = s3Source.map(_.utf8String).runWith(Sink.head)

result.futureValue shouldBe body
}

it should "download a range of file's bytes from S3 if bytes range given" in {

mockRangedDownload()

//#rangedDownload
val s3Source: Source[ByteString, NotUsed] =
s3Client.download(bucket, bucketKey, ByteRange(bytesRangeStart, bytesRangeEnd))
//#rangedDownload

val result: Future[Array[Byte]] = s3Source.map(_.toArray).runWith(Sink.head)

result.futureValue shouldBe rangeOfBody
}

it should "fail if request returns 404" in {

mock404s()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.github.tomakehurst.wiremock.client.WireMock._
import com.github.tomakehurst.wiremock.core.WireMockConfiguration._
import com.typesafe.config.ConfigFactory
import S3WireMockBase._
import com.github.tomakehurst.wiremock.matching.EqualToPattern

abstract class S3WireMockBase(_system: ActorSystem, _wireMockServer: WireMockServer) extends TestKit(_system) {

Expand Down Expand Up @@ -39,15 +40,30 @@ abstract class S3WireMockBase(_system: ActorSystem, _wireMockServer: WireMockSer
val uploadId = "VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA"
val etag = "5b27a21a97fcf8a7004dd1d906e7a5ba"
val url = s"http://testbucket.s3.amazonaws.com/testKey"
val (bytesRangeStart, bytesRangeEnd) = (2, 10)
val rangeOfBody = body.getBytes.slice(bytesRangeStart, bytesRangeEnd + 1)

def mockDownload(): Unit =
mock
.register(
get(urlEqualTo("/testKey")).willReturn(
get(urlEqualTo(s"/$bucketKey")).willReturn(
aResponse().withStatus(200).withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""").withBody(body)
)
)

def mockRangedDownload(): Unit =
mock
.register(
get(urlEqualTo(s"/$bucketKey"))
.withHeader("Range", new EqualToPattern(s"bytes=$bytesRangeStart-$bytesRangeEnd"))
.willReturn(
aResponse()
.withStatus(200)
.withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
.withBody(rangeOfBody)
)
)

def mockUpload(): Unit = {
mock
.register(
Expand Down Expand Up @@ -120,8 +136,8 @@ private object S3WireMockBase {
val s = (Thread.currentThread.getStackTrace map (_.getClassName) drop 1)
.dropWhile(_ matches "(java.lang.Thread|.*WireMockBase.?$)")
val reduced = s.lastIndexWhere(_ == clazz.getName) match {
case -1 s
case z s drop (z + 1)
case -1 => s
case z => s drop (z + 1)
}
reduced.head.replaceFirst(""".*\.""", "").replaceAll("[^a-zA-Z_0-9]", "_")
}
Expand Down

0 comments on commit e72e9ca

Please sign in to comment.