diff --git a/.changelog/c308474320474f91bb3dc286f554d575.json b/.changelog/c308474320474f91bb3dc286f554d575.json new file mode 100644 index 00000000000..065a2b5464f --- /dev/null +++ b/.changelog/c308474320474f91bb3dc286f554d575.json @@ -0,0 +1,8 @@ +{ + "id": "c3084743-2047-4f91-bb3d-c286f554d575", + "type": "feature", + "description": "feat: Add Xanadu Auth Token Generator", + "modules": [ + "feature/rds/auth" + ] +} \ No newline at end of file diff --git a/feature/rds/auth/connect.go b/feature/rds/auth/connect.go index 9a1406e7ed3..d6cb5042380 100644 --- a/feature/rds/auth/connect.go +++ b/feature/rds/auth/connect.go @@ -4,26 +4,33 @@ import ( "context" "fmt" "net/http" + "net/url" "strconv" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/internal/sdk" ) const ( - signingID = "rds-db" - emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + rdsAuthTokenID = "rds-db" + rdsClusterTokenID = "dsql" + emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + userAction = "DbConnect" + adminUserAction = "DbConnectAdmin" ) // BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken -type BuildAuthTokenOptions struct{} +type BuildAuthTokenOptions struct { + ExpiresIn time.Duration +} // BuildAuthToken will return an authorization token used as the password for a DB // connection. // -// * endpoint - Endpoint consists of the port needed to connect to the DB. : +// * endpoint - Endpoint consists of the hostname and port needed to connect to the DB. : // * region - Region is the location of where the DB is // * dbUser - User account within the database to sign in with // * creds - Credentials to be signed with @@ -50,12 +57,64 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid") } + values := url.Values{ + "Action": []string{"connect"}, + "DBUser": []string{dbUser}, + } + + return generateAuthToken(ctx, endpoint, region, values, rdsAuthTokenID, creds, optFns...) +} + +// GenerateDbConnectAuthToken will return an authorization token as the password for a +// DB connection. +// +// This is the regular user variant, see [GenerateDBConnectSuperUserAuthToken] for the superuser variant +// +// * endpoint - Endpoint is the hostname and optional port to connect to the DB +// * region - Region is the location of where the DB is +// * creds - Credentials to be signed with +func GenerateDbConnectAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) { + values := url.Values{ + "Action": []string{userAction}, + } + return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...) +} + +// GenerateDBConnectSuperUserAuthToken will return an authorization token as the password for a +// DB connection. +// +// This is the superuser user variant, see [GenerateDBConnectSuperUserAuthToken] for the regular user variant +// +// * endpoint - Endpoint is the hostname and optional port to connect to the DB +// * region - Region is the location of where the DB is +// * creds - Credentials to be signed with +func GenerateDBConnectSuperUserAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) { + values := url.Values{ + "Action": []string{adminUserAction}, + } + return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...) +} + +// All generate token functions are presigned URLs behind the scenes with the scheme stripped. +// This function abstracts generating this for all use cases +func generateAuthToken(ctx context.Context, endpoint, region string, values url.Values, signingID string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) { + if len(region) == 0 { + return "", fmt.Errorf("region is required") + } + if len(endpoint) == 0 { + return "", fmt.Errorf("endpoint is required") + } + o := BuildAuthTokenOptions{} for _, fn := range optFns { fn(&o) } + if o.ExpiresIn == 0 { + o.ExpiresIn = 15 * time.Minute + } + if creds == nil { return "", fmt.Errorf("credetials provider must not ne nil") } @@ -69,11 +128,7 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds if err != nil { return "", err } - values := req.URL.Query() - values.Set("Action", "connect") - values.Set("DBUser", dbUser) req.URL.RawQuery = values.Encode() - signer := v4.NewSigner() credentials, err := creds.Retrieve(ctx) @@ -81,12 +136,17 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds return "", err } - // Expire Time: 15 minute + expires := o.ExpiresIn + // if creds expire before expiresIn, set that as the expiration time + if credentials.CanExpire && !credentials.Expires.IsZero() { + credsExpireIn := credentials.Expires.Sub(sdk.NowTime()) + expires = min(o.ExpiresIn, credsExpireIn) + } query := req.URL.Query() - query.Set("X-Amz-Expires", "900") + query.Set("X-Amz-Expires", strconv.Itoa(int(expires.Seconds()))) req.URL.RawQuery = query.Encode() - signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC()) + signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, sdk.NowTime().UTC()) if err != nil { return "", err } diff --git a/feature/rds/auth/connect_test.go b/feature/rds/auth/connect_test.go index 7ccb2332f86..cfa0185274a 100644 --- a/feature/rds/auth/connect_test.go +++ b/feature/rds/auth/connect_test.go @@ -2,12 +2,15 @@ package auth_test import ( "context" + "net/url" "regexp" "strings" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go-v2/internal/sdk" ) func TestBuildAuthToken(t *testing.T) { @@ -67,14 +70,155 @@ func TestBuildAuthToken(t *testing.T) { } } +type dbAuthTestCase struct { + endpoint string + region string + expires time.Duration + credsExpireIn time.Duration + expectedHost string + expectedQueryParams []string + expectedError string +} + +type tokenGenFunc func(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error) + +func TestGenerateDbConnectAuthToken(t *testing.T) { + cases := map[string]dbAuthTestCase{ + "no region": { + endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306", + expectedError: "no region", + }, + "no endpoint": { + region: "us-west-2", + expectedError: "port", + }, + "endpoint with scheme": { + endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306", + region: "us-east-1", + expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306", + expectedQueryParams: []string{"Action=DbConnect"}, + }, + "endpoint without scheme": { + endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306", + region: "us-east-1", + expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306", + expectedQueryParams: []string{"Action=DbConnect"}, + }, + "endpoint without port": { + endpoint: "prod-instance.us-east-1.rds.amazonaws.com", + region: "us-east-1", + expectedHost: "prod-instance.us-east-1.rds.amazonaws.com", + expectedQueryParams: []string{"Action=DbConnect"}, + }, + "endpoint with region and expires": { + endpoint: "peccy.dsql.us-east-1.on.aws", + region: "us-east-1", + expires: time.Second * 450, + expectedHost: "peccy.dsql.us-east-1.on.aws", + expectedQueryParams: []string{ + "Action=DbConnect", + "X-Amz-Algorithm=AWS4-HMAC-SHA256", + "X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request", + "X-Amz-Date=20240827T000000Z", + "X-Amz-Expires=450"}, + }, + "pick credential expires when less than expires": { + endpoint: "peccy.dsql.us-east-1.on.aws", + region: "us-east-1", + credsExpireIn: time.Second * 100, + expires: time.Second * 450, + expectedHost: "peccy.dsql.us-east-1.on.aws", + expectedQueryParams: []string{ + "Action=DbConnect", + "X-Amz-Algorithm=AWS4-HMAC-SHA256", + "X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request", + "X-Amz-Date=20240827T000000Z", + "X-Amz-Expires=100"}, + }, + } + + for _, c := range cases { + creds := &staticCredentials{AccessKey: "akid", SecretKey: "secret", expiresIn: c.credsExpireIn} + defer withTempGlobalTime(time.Date(2024, time.August, 27, 0, 0, 0, 0, time.UTC))() + optFns := func(options *auth.BuildAuthTokenOptions) {} + if c.expires != 0 { + optFns = func(options *auth.BuildAuthTokenOptions) { + options.ExpiresIn = c.expires + } + } + verifyTestCase(auth.GenerateDbConnectAuthToken, c, creds, optFns, t) + + // Update the test case to use Superuser variant + updated := []string{} + for _, part := range c.expectedQueryParams { + if part == "Action=DbConnect" { + part = "Action=DbConnectAdmin" + } + updated = append(updated, part) + } + c.expectedQueryParams = updated + + verifyTestCase(auth.GenerateDBConnectSuperUserAuthToken, c, creds, optFns, t) + } +} + +func verifyTestCase(f tokenGenFunc, c dbAuthTestCase, creds aws.CredentialsProvider, optFns func(options *auth.BuildAuthTokenOptions), t *testing.T) { + token, err := f(context.Background(), c.endpoint, c.region, creds, optFns) + isErrorExpected := len(c.expectedError) > 0 + if err != nil && !isErrorExpected { + t.Fatalf("expect no err, got: %v", err) + } else if err == nil && isErrorExpected { + t.Fatalf("Expected error %v got none", c.expectedError) + } + // adding a scheme so we can parse it back as a URL. This is because comparing + // just direct string comparison was failing since "Action=DbConnect" is a substring or + // "Action=DBConnectSuperuser" + parsed, err := url.Parse("http://" + token) + if err != nil { + t.Fatalf("Couldn't parse the token %v to URL after adding a scheme, got: %v", token, err) + } + if parsed.Host != c.expectedHost { + t.Errorf("expect host %v, got %v", c.expectedHost, parsed.Host) + } + + q := parsed.Query() + queryValuePair := map[string]any{} + for k, v := range q { + pair := k + "=" + v[0] + queryValuePair[pair] = struct{}{} + } + + for _, part := range c.expectedQueryParams { + if _, ok := queryValuePair[part]; !ok { + t.Errorf("expect part %s to be present at token %s", part, token) + } + } + if token != "" && c.expires == 0 { + if !strings.Contains(token, "X-Amz-Expires=900") { + t.Errorf("expect token to contain default X-Amz-Expires value of 900, got %v", token) + } + } +} + type staticCredentials struct { AccessKey, SecretKey, Session string + expiresIn time.Duration } func (s *staticCredentials) Retrieve(ctx context.Context) (aws.Credentials, error) { - return aws.Credentials{ + c := aws.Credentials{ AccessKeyID: s.AccessKey, SecretAccessKey: s.SecretKey, SessionToken: s.Session, - }, nil + } + if s.expiresIn != 0 { + c.CanExpire = true + c.Expires = sdk.NowTime().Add(s.expiresIn) + } + return c, nil +} + +func withTempGlobalTime(t time.Time) func() { + sdk.NowTime = func() time.Time { return t } + return func() { sdk.NowTime = time.Now } }