From 8ccea2e254c9e82ee415046d5022b35e1c69dce7 Mon Sep 17 00:00:00 2001 From: Ben Drucker Date: Fri, 13 Aug 2021 07:11:55 -0700 Subject: [PATCH] Validate credentials in provider configure func (#579) * Revert "Merge pull request #571 from okta/panic_when_unauth" This reverts commit 999770566e5270a4e00c0df9db86b22f0c9ba9b9, reversing changes made to 3d825ade52f93b601fc853ba2730eb7b0361284f. * provider: validate credentials during initial config * return all errors * Keep some changes Co-authored-by: Bogdan Prodan --- okta/config.go | 11 ++++++--- okta/internal/transport/auth.go | 43 --------------------------------- okta/provider.go | 6 ++--- 3 files changed, 11 insertions(+), 49 deletions(-) delete mode 100644 okta/internal/transport/auth.go diff --git a/okta/config.go b/okta/config.go index 4107abdbb..d622a5206 100644 --- a/okta/config.go +++ b/okta/config.go @@ -51,7 +51,7 @@ type ( } ) -func (c *Config) loadAndValidate() error { +func (c *Config) loadAndValidate(ctx context.Context) error { logLevel := hclog.Level(c.logLevel) if os.Getenv("TF_LOG") != "" { logLevel = hclog.LevelFromString(os.Getenv("TF_LOG")) @@ -68,14 +68,14 @@ func (c *Config) loadAndValidate() error { retryableClient.RetryWaitMax = time.Second * time.Duration(c.maxWait) retryableClient.RetryMax = c.retryCount retryableClient.Logger = c.logger - retryableClient.HTTPClient.Transport = transport.NewAuthTransport(logging.NewTransport("Okta", retryableClient.HTTPClient.Transport)) + retryableClient.HTTPClient.Transport = logging.NewTransport("Okta", retryableClient.HTTPClient.Transport) retryableClient.ErrorHandler = errHandler retryableClient.CheckRetry = checkRetry httpClient = retryableClient.StandardClient() c.logger.Info(fmt.Sprintf("running with backoff http client, wait min %d, wait max %d, retry max %d", retryableClient.RetryWaitMin, retryableClient.RetryWaitMax, retryableClient.RetryMax)) } else { httpClient = cleanhttp.DefaultClient() - httpClient.Transport = transport.NewAuthTransport(logging.NewTransport("Okta", httpClient.Transport)) + httpClient.Transport = logging.NewTransport("Okta", httpClient.Transport) c.logger.Info("running with default http client") } @@ -112,6 +112,11 @@ func (c *Config) loadAndValidate() error { if err != nil { return err } + + if _, _, err := client.User.GetUser(ctx, "me"); err != nil { + return fmt.Errorf("invalid credentials: %w", err) + } + c.oktaClient = client c.supplementClient = &sdk.APISupplement{ RequestExecutor: client.GetRequestExecutor(), diff --git a/okta/internal/transport/auth.go b/okta/internal/transport/auth.go deleted file mode 100644 index 82473f143..000000000 --- a/okta/internal/transport/auth.go +++ /dev/null @@ -1,43 +0,0 @@ -package transport - -import ( - "net/http" - "strings" - - "github.com/okta/okta-sdk-golang/v2/okta" -) - -type AuthTransport struct { - base http.RoundTripper -} - -// NewAuthTransport stops the provider execution in case Okta API returns either 401 or 403 error codes. -func NewAuthTransport(base http.RoundTripper) *AuthTransport { - return &AuthTransport{ - base: base, - } -} - -// RoundTrip read the code based on the response from the API and terminates any further -// execution in case of unauthenticated or unauthorized requests. -// This can be invalid API token or a private key, no permissions to execute the request, -// no correct scopes were granted,etc. -func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := t.base.RoundTrip(req) - if err != nil { - return nil, err - } - if resp.StatusCode == http.StatusBadRequest { - oktaErr := okta.CheckResponseForError(resp).Error() - if strings.Contains(oktaErr, "You are not allowed any of the requested scopes") || - strings.Contains(oktaErr, "Invalid value for 'client_id' parameter") { - // panic here because hlog doesn't have Fatal method - panic(oktaErr) - } - } - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - // panic here because hlog doesn't have Fatal method - panic(okta.CheckResponseForError(resp)) - } - return resp, nil -} diff --git a/okta/provider.go b/okta/provider.go index 54a984d22..fc9b5928d 100644 --- a/okta/provider.go +++ b/okta/provider.go @@ -313,7 +313,7 @@ func deprecateIncorrectNaming(d *schema.Resource, newResource string) *schema.Re return d } -func providerConfigure(_ context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { +func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { log.Printf("[INFO] Initializing Okta client") config := Config{ orgName: d.Get("org_name").(string), @@ -334,8 +334,8 @@ func providerConfigure(_ context.Context, d *schema.ResourceData) (interface{}, if v := os.Getenv("OKTA_API_SCOPES"); v != "" && len(config.scopes) == 0 { config.scopes = strings.Split(v, ",") } - if err := config.loadAndValidate(); err != nil { - return nil, diag.Errorf("[ERROR] Error initializing the Okta SDK clients: %v", err) + if err := config.loadAndValidate(ctx); err != nil { + return nil, diag.Errorf("[ERROR] invalid configuration: %v", err) } return &config, nil }