Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] aws oidc skip aurora clusters without instances #47605

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/integration/integrationv1/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func (s *AWSOIDCService) ListDatabases(ctx context.Context, req *integrationpb.L
return nil, trace.Wrap(err)
}

listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, awsoidc.ListDatabasesRequest{
listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, s.logger, awsoidc.ListDatabasesRequest{
Region: req.Region,
RDSType: req.RdsType,
Engines: req.Engines,
Expand Down
33 changes: 18 additions & 15 deletions lib/integrations/awsoidc/listdatabases.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package awsoidc

import (
"context"
"log/slog"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/rds"
Expand Down Expand Up @@ -116,14 +117,14 @@ var listDatabasesPageSize int32 = 50
// https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html
// https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html
// It returns a list of Databases and an optional NextToken that can be used to fetch the next page
func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func ListDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

all := &ListDatabasesResponse{}
for {
res, err := listDatabases(ctx, clt, req)
res, err := listDatabases(ctx, clt, log, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -140,7 +141,7 @@ func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas
}
}

func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func listDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
// Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html
if req.RDSType == rdsTypeInstance {
ret, err := listDBInstances(ctx, clt, req)
Expand All @@ -151,7 +152,7 @@ func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas
}

// Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html
ret, err := listDBClusters(ctx, clt, req)
ret, err := listDBClusters(ctx, clt, log, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -199,7 +200,7 @@ func listDBInstances(ctx context.Context, clt ListDatabasesClient, req ListDatab
return ret, nil
}

func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func listDBClusters(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
describeDBClusterInput := &rds.DescribeDBClustersInput{
Filters: []rdsTypes.Filter{
{Name: &filterEngine, Values: req.Engines},
Expand Down Expand Up @@ -231,16 +232,23 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba
// To get this value, a member of the cluster is fetched and its Network Information is used to
// populate the RDS Cluster information.
// All the members have the same network information, so picking one at random should not matter.
clusterInstance, err := fetchSingleRDSDBInstance(ctx, clt, req, aws.ToString(db.DBClusterIdentifier))
instances, err := fetchRDSClusterInstances(ctx, clt, req, aws.ToString(db.DBClusterIdentifier))
if err != nil {
return nil, trace.Wrap(err)
}
if len(instances) == 0 {
log.InfoContext(ctx, "Skipping RDS cluster because it has no instances",
"cluster", aws.ToString(db.DBClusterIdentifier),
)
continue
}
instance := &instances[0]

if req.VpcId != "" && !subnetGroupIsInVPC(clusterInstance.DBSubnetGroup, req.VpcId) {
if req.VpcId != "" && !subnetGroupIsInVPC(instance.DBSubnetGroup, req.VpcId) {
continue
}

awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, clusterInstance)
awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, instance)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -251,7 +259,7 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba
return ret, nil
}

func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) (*rdsTypes.DBInstance, error) {
func fetchRDSClusterInstances(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) ([]rdsTypes.DBInstance, error) {
describeDBInstanceInput := &rds.DescribeDBInstancesInput{
Filters: []rdsTypes.Filter{
{Name: &filterDBClusterID, Values: []string{clusterID}},
Expand All @@ -262,12 +270,7 @@ func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req
if err != nil {
return nil, trace.Wrap(err)
}

if len(rdsDBs.DBInstances) == 0 {
return nil, trace.BadParameter("database cluster %s has no instance", clusterID)
}

return &rdsDBs.DBInstances[0], nil
return rdsDBs.DBInstances, nil
}

// subnetGroupIsInVPC is a simple helper to check if a db subnet group is in
Expand Down
88 changes: 73 additions & 15 deletions lib/integrations/awsoidc/listdatabases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
)

func stringPointer(s string) *string {
Expand Down Expand Up @@ -140,6 +141,7 @@ func TestListDatabases(t *testing.T) {

t.Run("without vpc filter", func(t *testing.T) {
t.Parallel()
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
// First page must return pageSize number of DBs
req := ListDatabasesRequest{
Region: "us-east-1",
Expand All @@ -149,25 +151,26 @@ func TestListDatabases(t *testing.T) {
NextToken: "",
}
for i := 0; i < totalDBs/int(listDatabasesPageSize); i++ {
resp, err := ListDatabases(ctx, mockListClient, req)
resp, err := ListDatabases(ctx, mockListClient, logger, req)
require.NoError(t, err)
require.Len(t, resp.Databases, int(listDatabasesPageSize))
require.NotEmpty(t, resp.NextToken)
req.NextToken = resp.NextToken
}
// Last page must return remaining databases and an empty token.
resp, err := ListDatabases(ctx, mockListClient, req)
resp, err := ListDatabases(ctx, mockListClient, logger, req)
require.NoError(t, err)
require.Len(t, resp.Databases, totalDBs%int(listDatabasesPageSize))
require.Empty(t, resp.NextToken)
})

t.Run("with vpc filter", func(t *testing.T) {
t.Parallel()
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
// First page must return at least pageSize number of DBs
var gotDatabases []types.Database
wantVPC := "vpc-2"
resp, err := ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err := ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand All @@ -188,7 +191,7 @@ func TestListDatabases(t *testing.T) {
gotDatabases = append(gotDatabases, resp.Databases...)

// Second page must return pageSize number of DBs
resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand All @@ -202,7 +205,7 @@ func TestListDatabases(t *testing.T) {
gotDatabases = append(gotDatabases, resp.Databases...)

// Third page must return only the remaining DBs and an empty nextToken
resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand Down Expand Up @@ -583,23 +586,77 @@ func TestListDatabases(t *testing.T) {
},

{
name: "cluster exists but no instance exists, returns an error",
name: "listing clusters returns all valid clusters and ignores the others",
req: ListDatabasesRequest{
Region: "us-east-1",
RDSType: "cluster",
Engines: []string{"postgres"},
NextToken: "",
},
mockClusters: []rdsTypes.DBCluster{{
Status: stringPointer("available"),
mockInstances: []rdsTypes.DBInstance{{
DBClusterIdentifier: stringPointer("my-dbc"),
DbClusterResourceId: stringPointer("db-123"),
Engine: stringPointer("aurora-postgresql"),
Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
DBSubnetGroup: &rdsTypes.DBSubnetGroup{
Subnets: []rdsTypes.Subnet{{SubnetIdentifier: aws.String("subnet-999")}},
VpcId: aws.String("vpc-999"),
},
}},
errCheck: trace.IsBadParameter,
mockClusters: []rdsTypes.DBCluster{
{
Status: stringPointer("available"),
DBClusterIdentifier: stringPointer("my-empty-cluster"),
DbClusterResourceId: stringPointer("db-456"),
Engine: stringPointer("aurora-mysql"),
Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
},
{
Status: stringPointer("available"),
DBClusterIdentifier: stringPointer("my-dbc"),
DbClusterResourceId: stringPointer("db-123"),
Engine: stringPointer("aurora-postgresql"),
Endpoint: stringPointer("aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
},
},
respCheck: func(t *testing.T, ldr *ListDatabasesResponse) {
require.Len(t, ldr.Databases, 1, "expected 1 database, got %d", len(ldr.Databases))
require.Empty(t, ldr.NextToken, "expected an empty NextToken")
expectedDB, err := types.NewDatabaseV3(
types.Metadata{
Name: "my-dbc",
Description: "Aurora cluster in ",
Labels: map[string]string{
"account-id": "123456789012",
"endpoint-type": "primary",
"engine": "aurora-postgresql",
"engine-version": "",
"region": "",
"status": "available",
"vpc-id": "vpc-999",
"teleport.dev/cloud": "AWS",
},
},
types.DatabaseSpecV3{
Protocol: "postgres",
URI: "aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
AWS: types.AWS{
AccountID: "123456789012",
RDS: types.RDS{
ClusterID: "my-dbc",
InstanceID: "aurora-instance-2",
ResourceID: "db-123",
Subnets: []string{"subnet-999"},
VPCID: "vpc-999",
},
},
},
)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0]))
},
errCheck: noErrorFunc,
},
{
name: "no region",
Expand Down Expand Up @@ -637,7 +694,8 @@ func TestListDatabases(t *testing.T) {
dbInstances: tt.mockInstances,
dbClusters: tt.mockClusters,
}
resp, err := ListDatabases(ctx, mockListClient, tt.req)
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
resp, err := ListDatabases(ctx, mockListClient, logger, tt.req)
require.True(t, tt.errCheck(err), "unexpected err: %v", err)
if err != nil {
return
Expand Down
Loading