Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v10] fix aws rds discovery invalid engine filter #18618

Merged
123 changes: 115 additions & 8 deletions lib/srv/db/cloud/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,25 @@ func (m *STSMock) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdenti
// RDSMock mocks AWS RDS API.
type RDSMock struct {
rdsiface.RDSAPI
DBInstances []*rds.DBInstance
DBClusters []*rds.DBCluster
DBInstances []*rds.DBInstance
DBClusters []*rds.DBCluster
DBEngineVersions []*rds.DBEngineVersion
}

func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) {
if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
return nil, trace.Wrap(err)
}
instances, err := applyInstanceFilters(m.DBInstances, input.Filters)
if err != nil {
return nil, trace.Wrap(err)
}
if aws.StringValue(input.DBInstanceIdentifier) == "" {
return &rds.DescribeDBInstancesOutput{
DBInstances: m.DBInstances,
DBInstances: instances,
}, nil
}
for _, instance := range m.DBInstances {
for _, instance := range instances {
if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) {
return &rds.DescribeDBInstancesOutput{
DBInstances: []*rds.DBInstance{instance},
Expand All @@ -78,19 +86,33 @@ func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.Des
}

func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error {
if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
return trace.Wrap(err)
}
instances, err := applyInstanceFilters(m.DBInstances, input.Filters)
if err != nil {
return trace.Wrap(err)
}
fn(&rds.DescribeDBInstancesOutput{
DBInstances: m.DBInstances,
DBInstances: instances,
}, true)
return nil
}

func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) {
if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
return nil, trace.Wrap(err)
}
clusters, err := applyClusterFilters(m.DBClusters, input.Filters)
if err != nil {
return nil, trace.Wrap(err)
}
if aws.StringValue(input.DBClusterIdentifier) == "" {
return &rds.DescribeDBClustersOutput{
DBClusters: m.DBClusters,
DBClusters: clusters,
}, nil
}
for _, cluster := range m.DBClusters {
for _, cluster := range clusters {
if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) {
return &rds.DescribeDBClustersOutput{
DBClusters: []*rds.DBCluster{cluster},
Expand All @@ -101,8 +123,15 @@ func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.Desc
}

func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil {
return trace.Wrap(err)
}
clusters, err := applyClusterFilters(m.DBClusters, input.Filters)
if err != nil {
return trace.Wrap(err)
}
fn(&rds.DescribeDBClustersOutput{
DBClusters: m.DBClusters,
DBClusters: clusters,
}, true)
return nil
}
Expand Down Expand Up @@ -529,3 +558,81 @@ func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.Upda
}
return nil, trace.NotFound("user %s not found", aws.StringValue(input.UserName))
}

// checkEngineFilters checks RDS filters to detect unrecognized engine filters.
func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error {
if len(filters) == 0 {
return nil
}
recognizedEngines := make(map[string]struct{})
for _, e := range engineVersions {
recognizedEngines[aws.StringValue(e.Engine)] = struct{}{}
}
for _, f := range filters {
if aws.StringValue(f.Name) != "engine" {
continue
}
for _, v := range f.Values {
if _, ok := recognizedEngines[aws.StringValue(v)]; !ok {
return trace.Errorf("unrecognized engine name %q", aws.StringValue(v))
}
}
}
return nil
}

// applyInstanceFilters filters RDS DBInstances using the provided RDS filters.
func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) {
if len(filters) == 0 {
return in, nil
}
var out []*rds.DBInstance
efs := engineFilterSet(filters)
for _, instance := range in {
if instanceEngineMatches(instance, efs) {
out = append(out, instance)
}
}
return out, nil
}

// applyClusterFilters filters RDS DBClusters using the provided RDS filters.
func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) {
if len(filters) == 0 {
return in, nil
}
var out []*rds.DBCluster
efs := engineFilterSet(filters)
for _, cluster := range in {
if clusterEngineMatches(cluster, efs) {
out = append(out, cluster)
}
}
return out, nil
}

// engineFilterSet builds a string set of engine names from a list of RDS filters.
func engineFilterSet(filters []*rds.Filter) map[string]struct{} {
out := make(map[string]struct{})
for _, f := range filters {
if aws.StringValue(f.Name) != "engine" {
continue
}
for _, v := range f.Values {
out[aws.StringValue(v)] = struct{}{}
}
}
return out
}

// instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set.
func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(instance.Engine)]
return ok
}

// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set.
func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(cluster.Engine)]
return ok
}
124 changes: 88 additions & 36 deletions lib/srv/db/cloud/watchers/rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.Databases, error

// getRDSDatabases returns a list of database resources representing RDS instances.
func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Databases, error) {
instances, err := getAllDBInstances(ctx, f.cfg.RDS, common.MaxPages)
instances, err := getAllDBInstances(ctx, f.cfg.RDS, common.MaxPages, f.log)
if err != nil {
return nil, common.ConvertError(err)
}
Expand Down Expand Up @@ -122,16 +122,25 @@ func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Data

// getAllDBInstances fetches all RDS instances using the provided client, up
// to the specified max number of pages.
func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (instances []*rds.DBInstance, err error) {
var pageNum int
err = rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{
Filters: rdsFilters(),
}, func(ddo *rds.DescribeDBInstancesOutput, lastPage bool) bool {
pageNum++
instances = append(instances, ddo.DBInstances...)
return pageNum <= maxPages
func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, log logrus.FieldLogger) ([]*rds.DBInstance, error) {
var instances []*rds.DBInstance
err := retryWithIndividualEngineFilters(log, rdsInstanceEngines(), func(filters []*rds.Filter) error {
var pageNum int
var out []*rds.DBInstance
err := rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{
Filters: filters,
}, func(ddo *rds.DescribeDBInstancesOutput, lastPage bool) bool {
pageNum++
instances = append(instances, ddo.DBInstances...)
return pageNum <= maxPages
})
if err == nil {
// only append to instances on nil error, just in case we have to retry.
instances = append(instances, out...)
}
return trace.Wrap(err)
})
return instances, common.ConvertError(err)
return instances, trace.Wrap(err)
}

// String returns the fetcher's string description.
Expand Down Expand Up @@ -173,7 +182,7 @@ func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.Databases, er

// getAuroraDatabases returns a list of database resources representing RDS clusters.
func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (types.Databases, error) {
clusters, err := getAllDBClusters(ctx, f.cfg.RDS, common.MaxPages)
clusters, err := getAllDBClusters(ctx, f.cfg.RDS, common.MaxPages, f.log)
if err != nil {
return nil, common.ConvertError(err)
}
Expand Down Expand Up @@ -248,16 +257,25 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type

// getAllDBClusters fetches all RDS clusters using the provided client, up to
// the specified max number of pages.
func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (clusters []*rds.DBCluster, err error) {
var pageNum int
err = rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
Filters: auroraFilters(),
}, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool {
pageNum++
clusters = append(clusters, ddo.DBClusters...)
return pageNum <= maxPages
func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, log logrus.FieldLogger) ([]*rds.DBCluster, error) {
var clusters []*rds.DBCluster
err := retryWithIndividualEngineFilters(log, auroraEngines(), func(filters []*rds.Filter) error {
var pageNum int
var out []*rds.DBCluster
err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
Filters: filters,
}, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool {
pageNum++
out = append(out, ddo.DBClusters...)
return pageNum <= maxPages
})
if err == nil {
// only append to clusters on nil error, just in case we have to retry.
clusters = append(clusters, out...)
}
return trace.Wrap(err)
})
return clusters, common.ConvertError(err)
return clusters, trace.Wrap(err)
}

// String returns the fetcher's string description.
Expand All @@ -266,26 +284,60 @@ func (f *rdsAuroraClustersFetcher) String() string {
f.cfg.Region, f.cfg.Labels)
}

// rdsFilters returns filters to make sure DescribeDBInstances call returns
// rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns
// only databases with engines Teleport supports.
func rdsFilters() []*rds.Filter {
return []*rds.Filter{{
Name: aws.String("engine"),
Values: aws.StringSlice([]string{
services.RDSEnginePostgres,
services.RDSEngineMySQL,
services.RDSEngineMariaDB}),
}}
func rdsInstanceEngines() []string {
return []string{
services.RDSEnginePostgres,
services.RDSEngineMySQL,
services.RDSEngineMariaDB,
}
}

// auroraFilters returns filters to make sure DescribeDBClusters call returns
// auroraEngines returns engines to make sure DescribeDBClusters call returns
// only databases with engines Teleport supports.
func auroraFilters() []*rds.Filter {
func auroraEngines() []string {
return []string{
services.RDSEngineAurora,
services.RDSEngineAuroraMySQL,
services.RDSEngineAuroraPostgres,
}
}

// rdsEngineFilter is a helper func to construct an RDS filter for engine names.
func rdsEngineFilter(engines []string) []*rds.Filter {
return []*rds.Filter{{
Name: aws.String("engine"),
Values: aws.StringSlice([]string{
services.RDSEngineAurora,
services.RDSEngineAuroraMySQL,
services.RDSEngineAuroraPostgres}),
Name: aws.String("engine"),
Values: aws.StringSlice(engines),
}}
}

// rdsFilterFn is a function that takes RDS filters and performs some operation with them, returning any error encountered.
type rdsFilterFn func([]*rds.Filter) error

// retryWithIndividualEngineFilters is a helper error handling function for AWS RDS unrecognized engine name filter errors,
// that will call the provided RDS querying function with filters, check the returned error,
// and if the error is an AWS unrecognized engine name error then it will retry once by calling the function with one filter
// at a time. If any error other than an AWS unrecognized engine name error occurs, this function will return that error
// without retrying, or skip retrying subsequent filters if it has already started to retry.
func retryWithIndividualEngineFilters(log logrus.FieldLogger, engines []string, fn rdsFilterFn) error {
err := fn(rdsEngineFilter(engines))
if err == nil {
return nil
}
if !common.IsUnrecognizedAWSEngineNameError(err) {
return trace.Wrap(err)
}
log.WithError(err).Warn("Teleport supports an engine which is unrecognized in this AWS region. Querying engines individually.")
for _, engine := range engines {
err := fn(rdsEngineFilter([]string{engine}))
if err == nil {
continue
}
if !common.IsUnrecognizedAWSEngineNameError(err) {
return trace.Wrap(err)
}
// skip logging unrecognized engine name error here, we already logged it in the initial attempt.
}
return nil
}
Loading