Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a special case for how S3 requests should be signed #1605

Merged
merged 10 commits into from
Jan 2, 2025
Merged
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Thank you!
# 0.18.28

* Better support for timestamps before Linux Epoch and trimming the Timestamp nanosecond part (see [#1623](https://github.com/disneystreaming/smithy4s/pull/1623))
* Adds a special for AWS request signing when S3 is being used.

# 0.18.27

Expand All @@ -23,8 +24,6 @@ Thank you!

# 0.18.25

* Add A flag to allow for numerics to be decoded from JSON strings (in smithy4s-json).
* Fixes issues in which applications of some Smithy traits would be incorrectly rendered in Scala code (see [#1602](https://github.com/disneystreaming/smithy4s/pull/1602)).
* Fixes an issue in which refinements wouldn't work on custom simple shapes (newtypes) (see [#1595](https://github.com/disneystreaming/smithy4s/pull/1595))
* Fixes a regression from 0.18.4 which incorrectly rendered default values for certain types (see [#1593](https://github.com/disneystreaming/smithy4s/pull/1593))
* Fixes an issue in which union members targetting Unit would fail to compile when used as traits (see [#1600](https://github.com/disneystreaming/smithy4s/pull/1600)).
Expand Down
176 changes: 105 additions & 71 deletions modules/aws-http4s/src/smithy4s/aws/internals/AwsSigning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ import java.nio.charset.StandardCharsets
*/
private[aws] object AwsSigning {

// see https://raw.githubusercontent.com/awslabs/aws-sdk-kotlin/main/codegen/sdk/aws-models/s3.json
val S3 = "AmazonS3"

def middleware[F[_]: Concurrent](
awsEnvironment: AwsEnvironment[F]
): Endpoint.Middleware[Client[F]] = new Endpoint.Middleware[Client[F]] {
Expand Down Expand Up @@ -88,6 +91,12 @@ private[aws] object AwsSigning {
credentials: F[AwsCredentials],
region: F[AwsRegion]
): Request[F] => F[Request[F]] = {

// S3 has special rules, in that it expects the X-Amz-Content-SHA256 to be set.
val preSign: PreSigner[F] =
if (serviceName == S3) new PreSigner.S3InMemorySigned[F]
else new PreSigner.Standard[F]

val contentType = org.http4s.headers.`Content-Type`.headerInstance
val `Content-Type` = contentType.name

Expand All @@ -106,79 +115,78 @@ private[aws] object AwsSigning {
}

// scalafmt: { align.preset = most, danglingParentheses.preset = false, maxColumn = 240, align.tokens = [{code = ":"}]}
(request: Request[F]) => {

val bodyF = request.body.chunks.compile.to(Chunk).map(_.flatten)
val awsHeadersF = (bodyF, timestamp, credentials, region).mapN { case (body, timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
val canonicalQueryString =
if (queryParams.isEmpty) ""
else
queryParams
.map { case (k, v) =>
URLEncoder.encode(k, StandardCharsets.UTF_8.name()) + "=" + URLEncoder.encode(v, StandardCharsets.UTF_8.name())
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
`Content-Type` -> request.contentType.map(contentType.value(_)).orNull,
`Host` -> request.uri.host.map(_.renderString).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)

val canonicalHeadersString = baseHeadersList
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val payloadHash = sha256HexDigest(body.toArray)
val pathString = request.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
.append(newline)
.append(pathString)
.append(newline)
.append(canonicalQueryString)
.append(newline)
.append(canonicalHeadersString)
.append(newline)
.append(newline)
.append(signedHeadersString)
.append(newline)
.append(payloadHash)
.result()

val canonicalRequestHash = sha256HexDigest(canonicalRequest)
val signatureKey = getSignatureKey(
credentials.secretAccessKey,
timestamp.conciseDate,
region.value,
endpointPrefix
)
val stringToSign = List[String](
algorithm,
timestamp.conciseDateTime,
credentialsScope,
canonicalRequestHash
).mkString(newline)
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}
(request: Request[F]) =>
preSign(request).flatMap { case (payloadHash, preparedRequest) =>
Copy link
Contributor Author

@Baccata Baccata Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is best compared side-by-side (split) in github, with whitespace hidden. This is one of the important changes

val awsHeadersF = (timestamp, credentials, region).mapN { case (timestamp, credentials, region) =>
val credentialsScope = s"${timestamp.conciseDate}/$region/$endpointPrefix/aws4_request"
val queryParams: Vector[(String, String)] =
request.uri.query.toVector.sorted.map { case (k, v) => k -> v.getOrElse("") }
val canonicalQueryString =
if (queryParams.isEmpty) ""
else
queryParams
.map { case (k, v) =>
URLEncoder.encode(k, StandardCharsets.UTF_8.name()) + "=" + URLEncoder.encode(v, StandardCharsets.UTF_8.name())
}
.mkString("&")

// // !\ Important: these must remain in the same order
val baseHeadersList = List(
`Content-Type` -> preparedRequest.contentType.map(contentType.value(_)).orNull,
`Host` -> preparedRequest.uri.host.map(_.renderString).orNull,
`X-Amz-Content-SHA256` -> preparedRequest.headers.get(`X-Amz-Content-SHA256`).map(_.head.value).orNull,
`X-Amz-Date` -> timestamp.conciseDateTime,
`X-Amz-Security-Token` -> credentials.sessionToken.orNull,
`X-Amz-Target` -> (serviceName + "." + operationName)
).filterNot(_._2 == null)

val canonicalHeadersString = baseHeadersList
.map { case (key, value) =>
key.toString.toLowerCase + ":" + value.trim
}
.mkString(newline)
lazy val signedHeadersString = baseHeadersList.map(_._1).map(_.toString.toLowerCase()).mkString(";")

val pathString = preparedRequest.uri.path.toAbsolute.renderString
val canonicalRequest = new StringBuilder()
.append(request.method.name.toUpperCase())
.append(newline)
.append(pathString)
.append(newline)
.append(canonicalQueryString)
.append(newline)
.append(canonicalHeadersString)
.append(newline)
.append(newline)
.append(signedHeadersString)
.append(newline)
.append(payloadHash)
.result()

val canonicalRequestHash = sha256HexDigest(canonicalRequest)
val signatureKey = getSignatureKey(
credentials.secretAccessKey,
timestamp.conciseDate,
region.value,
endpointPrefix
)
val stringToSign = List[String](
algorithm,
timestamp.conciseDateTime,
credentialsScope,
canonicalRequestHash
).mkString(newline)
val signature = toHexString(hmacSha256(stringToSign, signatureKey))
val authHeaderValue = s"${algorithm} Credential=${credentials.accessKeyId}/$credentialsScope, SignedHeaders=$signedHeadersString, Signature=$signature"
val authHeader = Headers("Authorization" -> authHeaderValue)
val baseHeaders = Headers(baseHeadersList.map { case (k, v) => Header.Raw(k, v) })
authHeader ++ baseHeaders
}

awsHeadersF.map { headers =>
request.transformHeaders(_ ++ headers)
awsHeadersF.map { headers =>
preparedRequest.transformHeaders(_ ++ headers)
}
}
}
}

private val newline = System.lineSeparator()
Expand All @@ -187,5 +195,31 @@ private[aws] object AwsSigning {
private val `X-Amz-Security-Token` = CIString("X-Amz-Security-Token")
private val `X-Amz-Target` = CIString("X-Amz-Target")
private val algorithm = "AWS4-HMAC-SHA256"
private val `X-Amz-Content-SHA256` = CIString("X-Amz-Content-SHA256")

private sealed trait PreSigner[F[_]] {
def apply(request: Request[F]): F[(String, Request[F])]
}
private object PreSigner {
class Standard[F[_]](implicit F: Concurrent[F]) extends PreSigner[F] {
def apply(request: Request[F]): F[(String, Request[F])] = {
request.body.compile.to(Chunk).map { inMemoryBody =>
val payloadHash = sha256HexDigest(inMemoryBody.toArray)
val newRequest = request.withBodyStream(fs2.Stream.chunk(inMemoryBody))
(payloadHash, newRequest)
}
}
}

class S3InMemorySigned[F[_]](implicit F: Concurrent[F]) extends PreSigner[F] {
def apply(request: Request[F]): F[(String, Request[F])] = {
request.body.compile.to(Chunk).map { inMemoryBody =>
val payloadHash = sha256HexDigest(inMemoryBody.toArray)
val newRequest = request.withBodyStream(fs2.Stream.chunk(inMemoryBody)).transformHeaders(_.put(Header.Raw(`X-Amz-Content-SHA256`, payloadHash)))
(payloadHash, newRequest)
}
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import java.time.Clock
import java.time.ZoneId
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._
import software.amazon.awssdk.auth.signer.AwsS3V4Signer
import software.amazon.awssdk.auth.signer.params.AwsS3V4SignerParams

/**
* This suite verifies our implementation of the AWS signature algorithm against
Expand Down Expand Up @@ -92,9 +94,9 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
val genAwsRequest = for {
httpMethod <- Gen.oneOf(SdkHttpMethod.values().toList)
host <- Gen.identifier
path <- Gen.listOf(Gen.identifier).map(_.mkString("/"))
path <- Gen.listOfN(3, Gen.identifier).map(_.mkString("/"))
content <- Gen.asciiStr
queryParams <- Gen.listOf(Gen.zip(Gen.identifier, Gen.alphaNumStr))
queryParams <- Gen.listOfN(3, Gen.zip(Gen.identifier, Gen.alphaNumStr))
} yield {
val builder = SdkHttpFullRequest
.builder()
Expand All @@ -112,7 +114,7 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
}

val gen: Gen[TestInput] = for {
serviceName <- Gen.identifier
serviceName <- Gen.oneOf(Gen.const("AmazonS3"), Gen.identifier)
operationName <- Gen.identifier
timestamp <- Gen.chooseNum(0L, 4102444800L).map(Timestamp.fromEpochSecond)
accessKeyId <- Gen.identifier
Expand Down Expand Up @@ -154,23 +156,53 @@ object AwsSignatureTest extends SimpleIOSuite with Checkers {
}

val region = Region.of(smithy4sRegion.value)
val signedAwsRequest = if (testInput.serviceName == AwsSigning.S3) {

val params = Aws4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.signingName(serviceName)
.build()

val awsSigner = Aws4Signer.create()
// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.build()
val signedAwsRequest = awsSigner.sign(amendedAwsRequest, params)
val signerParams = AwsS3V4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.enablePayloadSigning(true)
.signingName(serviceName)
.build()

// yes, this is an S3-specific signer.
val awsSigner = AwsS3V4Signer.create()

// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
//
// The hardcoded "required" header value is understood by the S3 signer as a signal that the `X-Amz-Content-SHA256` header
// should be replaced by the hash of the request payload, and that the same hash should be used in the signature.
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.appendHeader(
"X-Amz-Content-SHA256",
"required"
) // this is a magic addition that is understood by the S3 signer
.build()

awsSigner.sign(amendedAwsRequest, signerParams)
} else {
val params = Aws4SignerParams
.builder()
.awsCredentials(creds)
.signingRegion(region)
.signingClockOverride(fixedClock)
.signingName(serviceName)
.build()

val awsSigner = Aws4Signer.create()
// Amending the AWS Request to force the AMZ target as it's added automatically
// by our implementation
val amendedAwsRequest = awsRequest
.toBuilder()
.appendHeader("X-Amz-Target", serviceName + "." + operationName)
.build()
awsSigner.sign(amendedAwsRequest, params)
}

val smithy4sSigner = AwsSigning.signingFunction[IO](
serviceName,
Expand Down
3 changes: 2 additions & 1 deletion modules/core/src/smithy4s/Blob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ sealed trait Blob {

final def ++(other: Blob) = concat(other)

override def equals(other: Any): Boolean =
override def equals(other: Any): Boolean = {
other match {
case otherBlob: Blob => sameBytesAs(otherBlob)
case _ => false
}
}

override def hashCode(): Int = {
import util.hashing.MurmurHash3
Expand Down
Loading