Skip to content

Commit

Permalink
Merge customizations for DSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
AWS SDK for Go v2 automation user committed Dec 3, 2024
1 parent 3418b74 commit cfd99f8
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 13 deletions.
8 changes: 8 additions & 0 deletions .changelog/c308474320474f91bb3dc286f554d575.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "c3084743-2047-4f91-bb3d-c286f554d575",
"type": "feature",
"description": "feat: Add Xanadu Auth Token Generator",
"modules": [
"feature/rds/auth"
]
}
82 changes: 71 additions & 11 deletions feature/rds/auth/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,33 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)

const (
signingID = "rds-db"
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
rdsAuthTokenID = "rds-db"
rdsClusterTokenID = "dsql"
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
userAction = "DbConnect"
adminUserAction = "DbConnectAdmin"
)

// BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken
type BuildAuthTokenOptions struct{}
type BuildAuthTokenOptions struct {
ExpiresIn time.Duration
}

// BuildAuthToken will return an authorization token used as the password for a DB
// connection.
//
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
// * endpoint - Endpoint consists of the hostname and port needed to connect to the DB. <host>:<port>
// * region - Region is the location of where the DB is
// * dbUser - User account within the database to sign in with
// * creds - Credentials to be signed with
Expand All @@ -50,12 +57,64 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

values := url.Values{
"Action": []string{"connect"},
"DBUser": []string{dbUser},
}

return generateAuthToken(ctx, endpoint, region, values, rdsAuthTokenID, creds, optFns...)
}

// GenerateDbConnectAuthToken will return an authorization token as the password for a
// DB connection.
//
// This is the regular user variant, see [GenerateDBConnectSuperUserAuthToken] for the superuser variant
//
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
// * region - Region is the location of where the DB is
// * creds - Credentials to be signed with
func GenerateDbConnectAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
values := url.Values{
"Action": []string{userAction},
}
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
}

// GenerateDBConnectSuperUserAuthToken will return an authorization token as the password for a
// DB connection.
//
// This is the superuser user variant, see [GenerateDBConnectSuperUserAuthToken] for the regular user variant
//
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
// * region - Region is the location of where the DB is
// * creds - Credentials to be signed with
func GenerateDBConnectSuperUserAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
values := url.Values{
"Action": []string{adminUserAction},
}
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
}

// All generate token functions are presigned URLs behind the scenes with the scheme stripped.
// This function abstracts generating this for all use cases
func generateAuthToken(ctx context.Context, endpoint, region string, values url.Values, signingID string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
if len(region) == 0 {
return "", fmt.Errorf("region is required")
}
if len(endpoint) == 0 {
return "", fmt.Errorf("endpoint is required")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
fn(&o)
}

if o.ExpiresIn == 0 {
o.ExpiresIn = 15 * time.Minute
}

if creds == nil {
return "", fmt.Errorf("credetials provider must not ne nil")
}
Expand All @@ -69,24 +128,25 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds
if err != nil {
return "", err
}
values := req.URL.Query()
values.Set("Action", "connect")
values.Set("DBUser", dbUser)
req.URL.RawQuery = values.Encode()

signer := v4.NewSigner()

credentials, err := creds.Retrieve(ctx)
if err != nil {
return "", err
}

// Expire Time: 15 minute
expires := o.ExpiresIn
// if creds expire before expiresIn, set that as the expiration time
if credentials.CanExpire && !credentials.Expires.IsZero() {
credsExpireIn := credentials.Expires.Sub(sdk.NowTime())
expires = min(o.ExpiresIn, credsExpireIn)
}
query := req.URL.Query()
query.Set("X-Amz-Expires", "900")
query.Set("X-Amz-Expires", strconv.Itoa(int(expires.Seconds())))
req.URL.RawQuery = query.Encode()

signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC())
signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, sdk.NowTime().UTC())
if err != nil {
return "", err
}
Expand Down
148 changes: 146 additions & 2 deletions feature/rds/auth/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package auth_test

import (
"context"
"net/url"
"regexp"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)

func TestBuildAuthToken(t *testing.T) {
Expand Down Expand Up @@ -67,14 +70,155 @@ func TestBuildAuthToken(t *testing.T) {
}
}

type dbAuthTestCase struct {
endpoint string
region string
expires time.Duration
credsExpireIn time.Duration
expectedHost string
expectedQueryParams []string
expectedError string
}

type tokenGenFunc func(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error)

func TestGenerateDbConnectAuthToken(t *testing.T) {
cases := map[string]dbAuthTestCase{
"no region": {
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedError: "no region",
},
"no endpoint": {
region: "us-west-2",
expectedError: "port",
},
"endpoint with scheme": {
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint without scheme": {
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint without port": {
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-east-1",
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com",
expectedQueryParams: []string{"Action=DbConnect"},
},
"endpoint with region and expires": {
endpoint: "peccy.dsql.us-east-1.on.aws",
region: "us-east-1",
expires: time.Second * 450,
expectedHost: "peccy.dsql.us-east-1.on.aws",
expectedQueryParams: []string{
"Action=DbConnect",
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
"X-Amz-Date=20240827T000000Z",
"X-Amz-Expires=450"},
},
"pick credential expires when less than expires": {
endpoint: "peccy.dsql.us-east-1.on.aws",
region: "us-east-1",
credsExpireIn: time.Second * 100,
expires: time.Second * 450,
expectedHost: "peccy.dsql.us-east-1.on.aws",
expectedQueryParams: []string{
"Action=DbConnect",
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
"X-Amz-Date=20240827T000000Z",
"X-Amz-Expires=100"},
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "akid", SecretKey: "secret", expiresIn: c.credsExpireIn}
defer withTempGlobalTime(time.Date(2024, time.August, 27, 0, 0, 0, 0, time.UTC))()
optFns := func(options *auth.BuildAuthTokenOptions) {}
if c.expires != 0 {
optFns = func(options *auth.BuildAuthTokenOptions) {
options.ExpiresIn = c.expires
}
}
verifyTestCase(auth.GenerateDbConnectAuthToken, c, creds, optFns, t)

// Update the test case to use Superuser variant
updated := []string{}
for _, part := range c.expectedQueryParams {
if part == "Action=DbConnect" {
part = "Action=DbConnectAdmin"
}
updated = append(updated, part)
}
c.expectedQueryParams = updated

verifyTestCase(auth.GenerateDBConnectSuperUserAuthToken, c, creds, optFns, t)
}
}

func verifyTestCase(f tokenGenFunc, c dbAuthTestCase, creds aws.CredentialsProvider, optFns func(options *auth.BuildAuthTokenOptions), t *testing.T) {
token, err := f(context.Background(), c.endpoint, c.region, creds, optFns)
isErrorExpected := len(c.expectedError) > 0
if err != nil && !isErrorExpected {
t.Fatalf("expect no err, got: %v", err)
} else if err == nil && isErrorExpected {
t.Fatalf("Expected error %v got none", c.expectedError)
}
// adding a scheme so we can parse it back as a URL. This is because comparing
// just direct string comparison was failing since "Action=DbConnect" is a substring or
// "Action=DBConnectSuperuser"
parsed, err := url.Parse("http://" + token)
if err != nil {
t.Fatalf("Couldn't parse the token %v to URL after adding a scheme, got: %v", token, err)
}
if parsed.Host != c.expectedHost {
t.Errorf("expect host %v, got %v", c.expectedHost, parsed.Host)
}

q := parsed.Query()
queryValuePair := map[string]any{}
for k, v := range q {
pair := k + "=" + v[0]
queryValuePair[pair] = struct{}{}
}

for _, part := range c.expectedQueryParams {
if _, ok := queryValuePair[part]; !ok {
t.Errorf("expect part %s to be present at token %s", part, token)
}
}
if token != "" && c.expires == 0 {
if !strings.Contains(token, "X-Amz-Expires=900") {
t.Errorf("expect token to contain default X-Amz-Expires value of 900, got %v", token)
}
}
}

type staticCredentials struct {
AccessKey, SecretKey, Session string
expiresIn time.Duration
}

func (s *staticCredentials) Retrieve(ctx context.Context) (aws.Credentials, error) {
return aws.Credentials{
c := aws.Credentials{
AccessKeyID: s.AccessKey,
SecretAccessKey: s.SecretKey,
SessionToken: s.Session,
}, nil
}
if s.expiresIn != 0 {
c.CanExpire = true
c.Expires = sdk.NowTime().Add(s.expiresIn)
}
return c, nil
}

func withTempGlobalTime(t time.Time) func() {
sdk.NowTime = func() time.Time { return t }
return func() { sdk.NowTime = time.Now }
}

0 comments on commit cfd99f8

Please sign in to comment.