Skip to content

Commit

Permalink
Update token validation for aws-sdk-go v2
Browse files Browse the repository at this point in the history
We pass the full request details, it's less dependent on client
versions.
  • Loading branch information
justinsb authored and rifelpet committed Apr 27, 2024
1 parent a9cd5ef commit f97ecaf
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 85 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ require (
github.com/Masterminds/sprig/v3 v3.2.3
github.com/apparentlymart/go-cidr v1.1.0
github.com/aws/amazon-ec2-instance-selector/v2 v2.4.2-0.20231216170552-14d4dfcbaadf
github.com/aws/aws-sdk-go v1.51.29
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/config v1.27.11
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
Expand Down Expand Up @@ -111,6 +110,7 @@ require (
github.com/Microsoft/hcsshim v0.11.4 // indirect
github.com/armon/go-metrics v0.4.1 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aws/aws-sdk-go v1.51.29 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
Expand Down
25 changes: 22 additions & 3 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
Expand Down Expand Up @@ -63,18 +64,36 @@ func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenti
}, nil
}

type awsV2Token struct {
URL string `json:"url"`
Method string `json:"method"`
SignedHeader http.Header `json:"headers"`
}

func (a *awsAuthenticator) CreateToken(body []byte) (string, error) {
sha := sha256.Sum256(body)

presignClient := sts.NewPresignClient(a.sts)

// Ensure the signature is only valid for this particular body content.
stsRequest, _ := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
stsRequest, err := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) {
o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:])))
})
})
if err != nil {
return "", fmt.Errorf("building AWS STS presigned request: %w", err)
}

awsV2Token := &awsV2Token{
URL: stsRequest.URL,
Method: stsRequest.Method,
SignedHeader: stsRequest.SignedHeader,
}
token, err := json.Marshal(awsV2Token)
if err != nil {
return "", fmt.Errorf("converting token to json: %w", err)
}

headers, _ := json.Marshal(stsRequest.SignedHeader)
return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(headers), nil
return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(token), nil
}
141 changes: 98 additions & 43 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,23 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"testing"

// "github.com/aws/aws-sdk-go-v2/aws"
// "github.com/aws/aws-sdk-go-v2/credentials"
// "github.com/aws/aws-sdk-go-v2/service/sts"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

func TestAWSPresign(t *testing.T) {
// mockSTSServer := &mockSTSServer{t: t}
// awsConfig := aws.Config{}
// awsConfig.Region = "us-east-1"
// awsConfig.Credentials = credentials.NewStaticCredentialsProvider("accesskey", "secretkey", "")
// awsConfig.HTTPClient = mockSTSServer
// sts := sts.NewFromConfig(awsConfig)

mySession := session.Must(session.NewSession())
mySession.Config.Credentials = credentials.NewStaticCredentials("accesskey", "secretkey", "")
sts := sts.New(mySession)
mySession.Config.HTTPClient = &http.Client{Transport: &mockHTTPTransport{}}
mockSTSServer := &mockHTTPClient{t: t}
awsConfig := aws.Config{}
awsConfig.Region = "us-east-1"
awsConfig.Credentials = credentials.NewStaticCredentialsProvider("fakeaccesskey", "fakesecretkey", "")
awsConfig.HTTPClient = mockSTSServer
sts := sts.NewFromConfig(awsConfig)

a := &awsAuthenticator{
sts: sts,
}
Expand All @@ -68,24 +61,29 @@ func TestAWSPresign(t *testing.T) {
if err != nil {
t.Fatalf("decoding token as base64: %v", err)
}
headers := make(map[string][]string)
if err := json.Unmarshal([]byte(data), &headers); err != nil {
decoded := &awsV2Token{}
if err := json.Unmarshal([]byte(data), &decoded); err != nil {
t.Fatalf("decoding token as json: %v", err)
}

t.Logf("headers: %+v", headers)
t.Logf("decoded: %+v", decoded)

amzSignature := ""
amzSignedHeaders := ""
amzDate := ""
amzAlgorithm := ""
amzCredential := ""

authorization := ""
for header, values := range headers {
for header, values := range decoded.SignedHeader {
got := strings.Join(values, "||")
switch header {
case "User-Agent":
// Ignore
// TODO: Should we (can we) override the useragent?
case "X-Amz-Date":
if len(got) < 10 {
t.Errorf("expected %q header of at least 10 characters, got %q", header, got)
}
amzDate = got

case "Content-Length":
if want := "43"; got != want {
t.Errorf("unexpected %q header: got %q, want %q", header, got, want)
Expand All @@ -95,6 +93,11 @@ func TestAWSPresign(t *testing.T) {
t.Errorf("unexpected %q header: got %q, want %q", header, got, want)
}

case "Host":
if want := "sts.us-east-1.amazonaws.com"; got != want {
t.Errorf("unexpected %q header: got %q, want %q", header, got, want)
}

case "X-Kops-Request-Sha":
if want := bodyHashBase64; got != want {
t.Errorf("unexpected %q header: got %q, want %q", header, got, want)
Expand All @@ -103,34 +106,86 @@ func TestAWSPresign(t *testing.T) {
// Validated more deeply below
authorization = got
default:
t.Errorf("unexpected header %q", header)
t.Errorf("unexpected header %q: %q", header, got)
}
}

if !strings.HasPrefix(authorization, "AWS4-HMAC-SHA256 ") {
t.Errorf("unexpected authorization prefix, got %q", authorization)
}
for _, token := range strings.Split(strings.TrimPrefix(authorization, "AWS4-HMAC-SHA256 "), ", ") {
kv := strings.SplitN(token, "=", 2)
got := kv[1]
switch kv[0] {
case "Signature":
if len(got) < 10 {
t.Errorf("expected %q Authorization value of at least 10 characters, got %q", kv[0], got)
// TODO: This is only aws-sdk-go V1
if authorization != "" {
if !strings.HasPrefix(authorization, "AWS4-HMAC-SHA256 ") {
t.Errorf("unexpected authorization prefix, got %q", authorization)
}

for _, token := range strings.Split(strings.TrimPrefix(authorization, "AWS4-HMAC-SHA256 "), ", ") {
kv := strings.SplitN(token, "=", 2)
if len(kv) == 1 {
t.Errorf("invalid token %q in authorization header", token)
continue
}
case "Credential":
if len(got) < 10 {
t.Errorf("expected %q Authorization value of at least 10 characters, got %q", kv[0], got)
got := kv[1]
switch kv[0] {
case "Signature":
amzSignature = got
case "Credential":
amzCredential = got
case "SignedHeaders":
amzSignedHeaders = got
default:
t.Errorf("unknown token %q in authorization header", token)
}
case "SignedHeaders":
if want := "content-length;content-type;host;x-amz-date;x-kops-request-sha"; got != want {
t.Errorf("unexpected %q Authorization value: got %q, want %q", kv[0], got, want)
}
}

u, err := url.Parse(decoded.URL)
if err != nil {
t.Errorf("error parsing url %q: %v", decoded.URL, err)
}
for k, values := range u.Query() {
got := strings.Join(values, "||")

switch k {
case "Action":
if want := "GetCallerIdentity"; got != want {
t.Errorf("unexpected %q query param: got %q, want %q", k, got, want)
}
case "Version":
if want := "2011-06-15"; got != want {
t.Errorf("unexpected %q query param: got %q, want %q", k, got, want)
}
case "X-Amz-Date":
amzDate = k
case "X-Amz-Signature":
amzSignature = k
case "X-Amz-Credential":
amzCredential = got
case "X-Amz-SignedHeaders":
amzSignedHeaders = got
case "X-Amz-Algorithm":
amzAlgorithm = got
default:
t.Errorf("unknown token %q in authorization header", token)
t.Errorf("unknown token %q=%q in query", k, got)
}
}

if len(amzCredential) < 10 {
t.Errorf("expected amzCredential value of at least 10 characters, got %q", amzCredential)
}

if len(amzDate) < 10 {
t.Errorf("expected amz-date of at least 10 characters, got %q", amzDate)
}

if len(amzSignature) < 10 {
t.Errorf("expected amzSignature value of at least 10 characters, got %q", amzSignature)
}

if want := "AWS4-HMAC-SHA256"; amzAlgorithm != want {
t.Errorf("unexpected amzAlgorithm: got %q, want %q", amzAlgorithm, want)
}

if want := "host;x-kops-request-sha"; amzSignedHeaders != want {
t.Errorf("unexpected amzSignedHeaders: got %q, want %q", amzSignedHeaders, want)
}
}

type mockHTTPClient struct {
Expand Down
Loading

0 comments on commit f97ecaf

Please sign in to comment.