Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
fix: Continue fetching on incorrect account permissions (#1030)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbernays authored Jun 10, 2022
1 parent f43fe15 commit 71008d2
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 101 deletions.
140 changes: 78 additions & 62 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import (
wafv2types "github.com/aws/aws-sdk-go-v2/service/wafv2/types"
"github.com/aws/aws-sdk-go-v2/service/workspaces"
"github.com/aws/aws-sdk-go-v2/service/xray"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"github.com/cloudquery/cq-provider-sdk/provider/diag"
"github.com/cloudquery/cq-provider-sdk/provider/schema"
Expand All @@ -82,14 +83,14 @@ import (
type Client struct {
// Those are already normalized values after configure and this is why we don't want to hold
// config directly.
Accounts []Account
logLevel *string
maxRetries int
maxBackoff int
ServicesManager ServicesManager
logger hclog.Logger
// this is set by table clientList
AccountID string
GlobalRegion string
Region string
AutoscalingNamespace string
WAFScope wafv2types.Scope
Expand Down Expand Up @@ -174,17 +175,18 @@ type Services struct {
Xray XrayClient
}

type ServicesAccountRegionMap map[string]map[string]*Services
type ServicesPartitionAccountRegionMap map[string]map[string]map[string]*Services

// ServicesManager will hold the entire map of (account X region) services
type ServicesManager struct {
services ServicesAccountRegionMap
wafScopeServices map[string]*Services
services ServicesPartitionAccountRegionMap
wafScopeServices map[string]map[string]*Services
}

const (
defaultRegion = "us-east-1"
awsFailedToConfigureErrMsg = "failed to retrieve credentials for account %s. AWS Error: %w, detected aws env variables: %s"
awsOrgsFailedToFindMembers = "failed to list Org member accounts. Make sure that your credentials have the proper permissions"
defaultVar = "default"
cloudfrontScopeRegion = defaultRegion
)
Expand All @@ -209,29 +211,38 @@ var (
_ schema.ClientIdentifier = (*Client)(nil)
)

func (s *ServicesManager) ServicesByAccountAndRegion(accountId string, region string) *Services {
func (s *ServicesManager) ServicesByPartitionAccountAndRegion(partition, accountId, region string) *Services {
if region == "" {
region = defaultRegion
}
return s.services[accountId][region]
return s.services[partition][accountId][region]
}

func (s *ServicesManager) ServicesByAccountForWAFScope(accountId string) *Services {
return s.wafScopeServices[accountId]
func (s *ServicesManager) ServicesByAccountForWAFScope(partition, accountId string) *Services {
return s.wafScopeServices[partition][accountId]
}

func (s *ServicesManager) InitServicesForAccountAndRegion(accountId string, region string, services Services) {
if s.services[accountId] == nil {
s.services[accountId] = make(map[string]*Services)
func (s *ServicesManager) InitServicesForPartitionAccountAndRegion(partition, accountId, region string, services Services) {
if s.services == nil {
s.services = make(map[string]map[string]map[string]*Services)
}
s.services[accountId][region] = &services
if s.services[partition] == nil {
s.services[partition] = make(map[string]map[string]*Services)
}
if s.services[partition][accountId] == nil {
s.services[partition][accountId] = make(map[string]*Services)
}
s.services[partition][accountId][region] = &services
}

func (s *ServicesManager) InitServicesForAccountAndScope(accountId string, services Services) {
func (s *ServicesManager) InitServicesForPartitionAccountAndScope(partition, accountId string, services Services) {
if s.wafScopeServices == nil {
s.wafScopeServices = make(map[string]*Services)
s.wafScopeServices = make(map[string]map[string]*Services)
}
s.wafScopeServices[accountId] = &services
if s.wafScopeServices[partition] == nil {
s.wafScopeServices[partition] = make(map[string]*Services)
}
s.wafScopeServices[partition][accountId] = &services
}

func newS3ManagerFromConfig(cfg aws.Config) S3Manager {
Expand All @@ -244,17 +255,26 @@ func (s3Manager S3Manager) GetBucketRegion(ctx context.Context, bucket string, o
return manager.GetBucketRegion(ctx, s3Manager.s3Client, bucket, optFns...)
}

func NewAwsClient(logger hclog.Logger, accounts []Account) Client {
func NewAwsClient(logger hclog.Logger) Client {
return Client{
ServicesManager: ServicesManager{
services: ServicesAccountRegionMap{},
services: ServicesPartitionAccountRegionMap{},
},
logger: logger,
Accounts: accounts,
logger: logger,
}
}

func (s ServicesPartitionAccountRegionMap) Accounts() []string {
accounts := make([]string, 0)
for partitions := range s {
for account := range s[partitions] {
accounts = append(accounts, account)
}
}
return accounts
}
func (c *Client) Logger() hclog.Logger {
return &awsLogger{c.logger, c.Accounts}
return &awsLogger{c.logger, c.ServicesManager.services.Accounts()}
}

// Identify the given client
Expand All @@ -268,9 +288,9 @@ func (c *Client) Identify() string {
}

func (c *Client) Services() *Services {
s := c.ServicesManager.ServicesByAccountAndRegion(c.AccountID, c.Region)
s := c.ServicesManager.ServicesByPartitionAccountAndRegion(c.Partition, c.AccountID, c.Region)
if s == nil && c.WAFScope == wafv2types.ScopeCloudfront {
return c.ServicesManager.ServicesByAccountForWAFScope(c.AccountID)
return c.ServicesManager.ServicesByAccountForWAFScope(c.Partition, c.AccountID)
}
return s
}
Expand All @@ -290,25 +310,9 @@ func (c *Client) PartitionGlobalARN(service AWSService, idParts ...string) strin
return makeARN(service, c.Partition, "", "", idParts...).String()
}

func (c *Client) withAccountID(accountID string) *Client {
return &Client{
Partition: c.Partition,
Accounts: c.Accounts,
logLevel: c.logLevel,
maxRetries: c.maxRetries,
maxBackoff: c.maxBackoff,
ServicesManager: c.ServicesManager,
logger: c.logger.With("account_id", obfuscateAccountId(accountID)),
AccountID: accountID,
Region: c.Region,
AutoscalingNamespace: c.AutoscalingNamespace,
}
}

func (c *Client) withAccountIDAndRegion(accountID, region string) *Client {
func (c *Client) withPartitionAccountIDAndRegion(partition, accountID, region string) *Client {
return &Client{
Partition: c.Partition,
Accounts: c.Accounts,
Partition: partition,
logLevel: c.logLevel,
maxRetries: c.maxRetries,
maxBackoff: c.maxBackoff,
Expand All @@ -321,10 +325,9 @@ func (c *Client) withAccountIDAndRegion(accountID, region string) *Client {
}
}

func (c *Client) withAccountIDRegionAndNamespace(accountID, region, namespace string) *Client {
func (c *Client) withPartitionAccountIDRegionAndNamespace(partition, accountID, region, namespace string) *Client {
return &Client{
Partition: c.Partition,
Accounts: c.Accounts,
Partition: partition,
logLevel: c.logLevel,
maxRetries: c.maxRetries,
maxBackoff: c.maxBackoff,
Expand All @@ -337,10 +340,9 @@ func (c *Client) withAccountIDRegionAndNamespace(accountID, region, namespace st
}
}

func (c *Client) withAccountIDRegionAndScope(accountID, region string, scope wafv2types.Scope) *Client {
func (c *Client) withPartitionAccountIDRegionAndScope(partition, accountID, region string, scope wafv2types.Scope) *Client {
return &Client{
Partition: c.Partition,
Accounts: c.Accounts,
Partition: partition,
logLevel: c.logLevel,
maxRetries: c.maxRetries,
maxBackoff: c.maxBackoff,
Expand Down Expand Up @@ -413,7 +415,7 @@ func configureAwsClient(ctx context.Context, logger hclog.Logger, awsConfig *Con

if err != nil {
logger.Error("error loading default config", "err", err)
return awsCfg, fmt.Errorf(awsFailedToConfigureErrMsg, account.AccountName, err, checkEnvVariables())
return awsCfg, err
}

if account.RoleARN != "" {
Expand Down Expand Up @@ -444,7 +446,7 @@ func configureAwsClient(ctx context.Context, logger hclog.Logger, awsConfig *Con
// Test out retrieving credentials
if _, err := awsCfg.Credentials.Retrieve(ctx); err != nil {
logger.Error("error retrieving credentials", "err", err)
return awsCfg, classifyError(fmt.Errorf(awsFailedToConfigureErrMsg, account.AccountName, err, checkEnvVariables()), diag.INTERNAL, nil, diag.WithSeverity(diag.ERROR), diag.WithNoOverwrite())
return awsCfg, err
}

return awsCfg, err
Expand All @@ -455,14 +457,22 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe

ctx := context.Background()
awsConfig := providerConfig.(*Config)
client := NewAwsClient(logger, awsConfig.Accounts)
client := NewAwsClient(logger)
client.GlobalRegion = awsConfig.GlobalRegion
var adminAccountSts AssumeRoleAPIClient

if awsConfig.Organization != nil {
var err error
awsConfig.Accounts, adminAccountSts, err = loadOrgAccounts(ctx, logger, awsConfig)
if err != nil {
logger.Error("error getting child accounts", "err", err)

var ae smithy.APIError
if errors.As(err, &ae) {
if strings.Contains(ae.ErrorCode(), "AccessDenied") {
return nil, diags.Add(diag.FromError(fmt.Errorf(awsOrgsFailedToFindMembers), diag.ACCESS, diag.WithSeverity(diag.ERROR)))
}
}
return nil, diags.Add(classifyError(err, diag.INTERNAL, nil))
}
}
Expand Down Expand Up @@ -503,6 +513,14 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe
diags = diags.Add(diag.FromError(errors.New("unable to assume role in account"), diag.ACCESS, diag.WithSeverity(diag.WARNING)))
continue
}
var ae smithy.APIError
if errors.As(err, &ae) {
if strings.Contains(ae.ErrorCode(), "AccessDenied") {
diags = diags.Add(diag.FromError(fmt.Errorf(awsFailedToConfigureErrMsg, account.AccountName, err, checkEnvVariables()), diag.ACCESS, diag.WithSeverity(diag.WARNING)))
continue
}
}

return nil, diags.Add(diag.FromError(err, diag.ACCESS))
}

Expand All @@ -517,35 +535,33 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe
}
})
if err != nil {
return nil, diags.Add(classifyError(fmt.Errorf("failed to find disabled regions for account %s. AWS Error: %w", account.AccountName, err), diag.INTERNAL, nil, diag.WithSeverity(diag.ERROR), diag.WithNoOverwrite()))
diags = diags.Add(diag.FromError(fmt.Errorf("failed to find disabled regions for account %s. AWS Error: %w", account.AccountName, err), diag.ACCESS, diag.WithSeverity(diag.WARNING)))
continue
}
account.Regions = filterDisabledRegions(localRegions, res.Regions)

if len(account.Regions) == 0 {
return nil, diags.Add(diag.FromError(fmt.Errorf("no enabled regions provided in config for account %s", account.AccountName), diag.USER))
diags = diags.Add(diag.FromError(fmt.Errorf("no enabled regions provided in config for account %s", account.AccountName), diag.ACCESS, diag.WithSeverity(diag.WARNING)))
continue
}
awsCfg.Region = account.Regions[0]
output, err := getAccountId(ctx, awsCfg)
if err != nil {
return nil, diags.Add(classifyError(err, diag.INTERNAL, nil))
// return nil, diags.Add(classifyError(err, diag.INTERNAL, nil))
diags = diags.Add(diag.FromError(fmt.Errorf("failed to find disabled regions for account %s. AWS Error: %w", account.AccountName, err), diag.ACCESS, diag.WithSeverity(diag.WARNING)))
continue
}
iamArn, err := arn.Parse(*output.Arn)
if err != nil {
return nil, diags.Add(classifyError(err, diag.INTERNAL, nil))
}
if client.AccountID == "" {
// set default
client.AccountID = *output.Account
client.Region = account.Regions[0]
client.Partition = iamArn.Partition
client.Accounts = append(client.Accounts, Account{ID: *output.Account, RoleARN: *output.Arn})
}

for _, region := range account.Regions {
client.ServicesManager.InitServicesForAccountAndRegion(*output.Account, region, initServices(region, awsCfg))
client.ServicesManager.InitServicesForPartitionAccountAndRegion(iamArn.Partition, *output.Account, region, initServices(region, awsCfg))
}
client.ServicesManager.InitServicesForAccountAndScope(*output.Account, initServices(cloudfrontScopeRegion, awsCfg))
client.ServicesManager.InitServicesForPartitionAccountAndScope(iamArn.Partition, *output.Account, initServices(cloudfrontScopeRegion, awsCfg))
}
if len(client.Accounts) == 0 {
if len(client.ServicesManager.services) == 0 {
return nil, diags.Add(diag.FromError(errors.New("no accounts instantiated"), diag.USER))
}
return &client, diags
Expand Down
1 change: 1 addition & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Config struct {
AWSDebug bool `hcl:"aws_debug,optional"`
MaxRetries int `hcl:"max_retries,optional" default:"10"`
MaxBackoff int `hcl:"max_backoff,optional" default:"30"`
GlobalRegion string `hcl:"global_region,optional" default:"us-east-1"`
}

func (Config) Example() string {
Expand Down
12 changes: 6 additions & 6 deletions client/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ var throttleCodes = map[string]struct{}{
func ErrorClassifier(meta schema.ClientMeta, resourceName string, err error) diag.Diagnostics {
client := meta.(*Client)

return classifyError(err, diag.RESOLVING, client.Accounts, diag.WithResourceName(resourceName), includeResourceIdWithAccount(client, err))
return classifyError(err, diag.RESOLVING, client.ServicesManager.services.Accounts(), diag.WithResourceName(resourceName), includeResourceIdWithAccount(client, err))
}

func classifyError(err error, fallbackType diag.Type, accounts []Account, opts ...diag.BaseErrorOption) diag.Diagnostics {
func classifyError(err error, fallbackType diag.Type, accounts []string, opts ...diag.BaseErrorOption) diag.Diagnostics {
var ae smithy.APIError
if errors.As(err, &ae) {
switch ae.ErrorCode() {
Expand Down Expand Up @@ -199,7 +199,7 @@ func ParseSummaryMessage(err error) diag.BaseErrorOption {
}

// RedactError redacts a given diagnostic and returns a RedactedDiagnostic containing both original and redacted versions
func RedactError(aa []Account, e diag.Diagnostic) diag.Diagnostic {
func RedactError(aa []string, e diag.Diagnostic) diag.Diagnostic {
r := diag.NewBaseError(
nil,
e.Type(),
Expand Down Expand Up @@ -227,9 +227,9 @@ func isCodeThrottle(code string) bool {
return ok
}

func removePII(aa []Account, msg string) string {
for i := range aa {
msg = strings.ReplaceAll(msg, " AccountID "+aa[i].ID, " AccountID xxxx")
func removePII(aa []string, msg string) string {
for _, i := range aa {
msg = strings.ReplaceAll(msg, " AccountID "+i, " AccountID xxxx")
}
msg = requestIdRegex.ReplaceAllString(msg, " ${1} xxxx")
msg = hostIdRegex.ReplaceAllString(msg, " HostID: xxxx")
Expand Down
2 changes: 1 addition & 1 deletion client/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestRemovePII(t *testing.T) {
},
}
for i, tc := range cases {
res := removePII([]Account{{ID: "123456789"}}, tc.Input)
res := removePII([]string{"123456789"}, tc.Input)
assert.Equalf(t, tc.Expected, res, "Case #%d", i+1)
}
}
4 changes: 2 additions & 2 deletions client/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ func IgnoreNotAvailableRegion(err error) bool {
return false
}

func accountObfusactor(aa []Account, msg string) string {
func accountObfusactor(aa []string, msg string) string {
for _, a := range aa {
msg = strings.ReplaceAll(msg, a.ID, obfuscateAccountId(a.ID))
msg = strings.ReplaceAll(msg, a, obfuscateAccountId(a))
}
return msg
}
Expand Down
2 changes: 1 addition & 1 deletion client/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

type awsLogger struct {
hclog.Logger
accounts []Account
accounts []string
}

func (a awsLogger) Log(level hclog.Level, msg string, args ...interface{}) {
Expand Down
Loading

0 comments on commit 71008d2

Please sign in to comment.