Skip to content

Commit

Permalink
Merge pull request #143 from m3nadav/avoid_s3_redirect_response
Browse files Browse the repository at this point in the history
update s3 multipart uploads to use region-specific uri (resolves #142)
  • Loading branch information
patriknw authored Jan 13, 2017
2 parents 461a999 + d2a579a commit 93453a7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
27 changes: 15 additions & 12 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/HttpRequests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,25 @@ import scala.concurrent.{ ExecutionContext, Future }

private[alpakka] object HttpRequests {

def getDownloadRequest(s3Location: S3Location)(implicit conf: S3Settings): HttpRequest =
s3Request(s3Location)
def getDownloadRequest(s3Location: S3Location, region: String)(implicit conf: S3Settings): HttpRequest =
s3Request(s3Location, region: String)

def initiateMultipartUploadRequest(s3Location: S3Location, contentType: ContentType, cannedAcl: CannedAcl)(
implicit conf: S3Settings): HttpRequest =
s3Request(s3Location, HttpMethods.POST, _.withQuery(Query("uploads")))
def initiateMultipartUploadRequest(s3Location: S3Location, contentType: ContentType, cannedAcl: CannedAcl, region: String)(
implicit conf: S3Settings): HttpRequest =
s3Request(s3Location, region, HttpMethods.POST, _.withQuery(Query("uploads")))
.withDefaultHeaders(RawHeader("x-amz-acl", cannedAcl.value))
.withEntity(HttpEntity.empty(contentType))

def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: Source[ByteString, _], payloadSize: Int)(
def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: Source[ByteString, _], payloadSize: Int, region: String)(
implicit conf: S3Settings): HttpRequest =
s3Request(
upload.s3Location,
region,
HttpMethods.PUT,
_.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
).withEntity(HttpEntity(ContentTypes.`application/octet-stream`, payloadSize, payload))

def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)])(
def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)], region: String)(
implicit ec: ExecutionContext,
conf: S3Settings): Future[HttpRequest] = {

Expand All @@ -48,30 +49,32 @@ private[alpakka] object HttpRequests {
} yield {
s3Request(
upload.s3Location,
region,
HttpMethods.POST,
_.withQuery(Query("uploadId" -> upload.uploadId))
).withEntity(entity)
}
}

private[this] def s3Request(s3Location: S3Location,
region: String,
method: HttpMethod = HttpMethods.GET,
uriFn: (Uri => Uri) = identity)(implicit conf: S3Settings): HttpRequest = {

def requestHost(s3Location: S3Location)(implicit conf: S3Settings): Uri.Host =
def requestHost(s3Location: S3Location, region: String)(implicit conf: S3Settings): Uri.Host =
conf.proxy match {
case None => Uri.Host(s"${s3Location.bucket}.s3.amazonaws.com")
case None => Uri.Host(s"${s3Location.bucket}.s3-${region}.amazonaws.com")
case Some(proxy) => Uri.Host(proxy.host)
}

def requestUri(s3Location: S3Location)(implicit conf: S3Settings): Uri = {
val uri = Uri(s"/${s3Location.key}").withHost(requestHost(s3Location)).withScheme("https")
def requestUri(s3Location: S3Location, region: String)(implicit conf: S3Settings): Uri = {
val uri = Uri(s"/${s3Location.key}").withHost(requestHost(s3Location, region)).withScheme("https")
conf.proxy match {
case None => uri
case Some(proxy) => uri.withPort(proxy.port)
}
}

HttpRequest(method).withHeaders(Host(requestHost(s3Location))).withUri(uriFn(requestUri(s3Location)))
HttpRequest(method).withHeaders(Host(requestHost(s3Location, region))).withUri(uriFn(requestUri(s3Location, region)))
}
}
10 changes: 5 additions & 5 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials, region: Strin

def download(s3Location: S3Location): Source[ByteString, NotUsed] = {
import mat.executionContext
Source.fromFuture(signAndGet(HttpRequests.getDownloadRequest(s3Location)).map(_.dataBytes)).flatMapConcat(identity)
Source.fromFuture(signAndGet(HttpRequests.getDownloadRequest(s3Location, region)).map(_.dataBytes)).flatMapConcat(identity)
}

/**
Expand All @@ -83,7 +83,7 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials, region: Strin
cannedAcl: CannedAcl): Future[MultipartUpload] = {
import mat.executionContext

val req = HttpRequests.initiateMultipartUploadRequest(s3Location, contentType, cannedAcl)
val req = HttpRequests.initiateMultipartUploadRequest(s3Location, contentType, cannedAcl, region)

val response = for {
signedReq <- Signer.signedRequest(req, signingKey)
Expand All @@ -103,8 +103,8 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials, region: Strin
parts: Seq[SuccessfulUploadPart]): Future[CompleteMultipartUploadResult] = {
import mat.executionContext

for (req <- HttpRequests
.completeMultipartUploadRequest(parts.head.multipartUpload, parts.map { case p => (p.index, p.etag) });
for (req <- HttpRequests.completeMultipartUploadRequest(parts.head.multipartUpload,
parts.map { case p => (p.index, p.etag) }, region);
res <- signAndGetAs[CompleteMultipartUploadResult](req)) yield res
}

Expand Down Expand Up @@ -149,7 +149,7 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials, region: Strin
.concatSubstreams
.zipWith(requestInfo) {
case (payload, (uploadInfo, chunkIndex)) =>
(HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload.data, payload.size),
(HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload.data, payload.size, region),
(uploadInfo, chunkIndex))
}
.mapAsync(parallelism) { case (req, info) => Signer.signedRequest(req, signingKey).zip(Future.successful(info)) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ class HttpRequestsSpec extends FlatSpec with Matchers {
val location = S3Location("bucket", "image-1024@2x")
val contentType = MediaTypes.`image/jpeg`
val acl = CannedAcl.PublicRead
val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, acl)
val req = HttpRequests.initiateMultipartUploadRequest(location, contentType, acl, "us-east-1")

req.entity shouldEqual HttpEntity.empty(contentType)
req.headers should contain(RawHeader("x-amz-acl", acl.value))
req.uri.authority.host.toString shouldEqual "bucket.s3-us-east-1.amazonaws.com"
}
}

0 comments on commit 93453a7

Please sign in to comment.