From 42e48cccffc44fa104b7c101a10ee2eeb44c1651 Mon Sep 17 00:00:00 2001 From: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> Date: Thu, 2 May 2024 19:10:02 -0500 Subject: [PATCH] Adds callback mode that is direct to vault Signed-off-by: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> --- backend.go | 1 + cli.go | 130 ++++++++++--- cli_responses.go | 461 ---------------------------------------------- path_oidc.go | 179 +++++++++++++----- path_oidc_test.go | 68 +++++-- path_role.go | 31 ++++ path_role_test.go | 3 + 7 files changed, 327 insertions(+), 546 deletions(-) delete mode 100644 cli_responses.go diff --git a/backend.go b/backend.go index 85041de0..dc09966a 100644 --- a/backend.go +++ b/backend.go @@ -63,6 +63,7 @@ func backend() *jwtAuthBackend { "login", "oidc/auth_url", "oidc/callback", + "oidc/poll", // Uncomment to mount simple UI handler for local development // "ui", diff --git a/cli.go b/cli.go index 9c61f868..b026054f 100644 --- a/cli.go +++ b/cli.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + "net/url" "os" "os/signal" "path" @@ -27,9 +28,11 @@ const ( defaultPort = "8250" defaultCallbackHost = "localhost" defaultCallbackMethod = "http" + defaultCallbackMode = "client" FieldCallbackHost = "callbackhost" FieldCallbackMethod = "callbackmethod" + FieldCallbackMode = "callbackmode" FieldListenAddress = "listenaddress" FieldPort = "port" FieldCallbackPort = "callbackport" @@ -69,19 +72,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro port = defaultPort } + var vaultURL *url.URL + callbackMode, ok := m[FieldCallbackMode] + if !ok { + callbackMode = defaultCallbackMode + } else if callbackMode == "direct" { + vaultAddr := os.Getenv("VAULT_ADDR") + if vaultAddr != "" { + vaultURL, _ = url.Parse(vaultAddr) + } + } + callbackHost, ok := m[FieldCallbackHost] if !ok { - callbackHost = defaultCallbackHost + if vaultURL != nil { + callbackHost = vaultURL.Hostname() + } else { + callbackHost = defaultCallbackHost + } } callbackMethod, ok := m[FieldCallbackMethod] if !ok { - callbackMethod = defaultCallbackMethod + if vaultURL != nil { + callbackMethod = vaultURL.Scheme + } else { + callbackMethod = defaultCallbackMethod + } } callbackPort, ok := m[FieldCallbackPort] if !ok { - callbackPort = port + if vaultURL != nil { + callbackPort = vaultURL.Port() + "/v1/auth/" + mount + } else { + callbackPort = port + } } parseBool := func(f string, d bool) (bool, error) { @@ -115,20 +141,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro role := m["role"] - authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost) + authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost) if err != nil { return nil, err } - // Set up callback handler doneCh := make(chan loginResp) - http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh)) - listener, err := net.Listen("tcp", listenAddress+":"+port) - if err != nil { - return nil, err + var pollInterval string + var interval int + var state string + var listener net.Listener + + if secret != nil { + pollInterval, _ = secret.Data["poll_interval"].(string) + state, _ = secret.Data["state"].(string) + } + if callbackMode == "direct" { + if state == "" { + return nil, errors.New("no state returned in direct callback mode") + } + if pollInterval == "" { + return nil, errors.New("no poll_interval returned in direct callback mode") + } + interval, err = strconv.Atoi(pollInterval) + if err != nil { + return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer") + } + } else { + if state != "" { + return nil, errors.New("state returned in client callback mode, try direct") + } + if pollInterval != "" { + return nil, errors.New("poll_interval returned in client callback mode") + } + // Set up callback handler + http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh)) + + listener, err := net.Listen("tcp", listenAddress+":"+port) + if err != nil { + return nil, err + } + defer listener.Close() } - defer listener.Close() // Open the default browser to the callback URL. if !skipBrowserLaunch { @@ -144,6 +199,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro } fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n") + if callbackMode == "direct" { + data := map[string]interface{}{ + "state": state, + "client_nonce": clientNonce, + } + pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount) + for { + time.Sleep(time.Duration(interval) * time.Second) + + secret, err := c.Logical().Write(pollUrl, data) + if err == nil { + return secret, nil + } + if !strings.HasSuffix(err.Error(), "authorization_pending") { + return nil, err + } + // authorization is pending, try again + } + } + // Start local server go func() { err := http.Serve(listener, nil) @@ -210,12 +285,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha } } -func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) { +func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) { var authURL string clientNonce, err := base62.Random(20) if err != nil { - return "", "", err + return "", "", nil, err } redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort) @@ -227,7 +302,7 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data) if err != nil { - return "", "", err + return "", "", nil, err } if secret != nil { @@ -235,10 +310,10 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho } if authURL == "" { - return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI) + return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI) } - return authURL, clientNonce, nil + return authURL, clientNonce, secret, nil } // parseError converts error from the API into summary and detailed portions. @@ -292,8 +367,8 @@ Usage: vault login -method=oidc [CONFIG K=V...] https://accounts.google.com/o/oauth2/v2/... - The default browser will be opened for the user to complete the login. Alternatively, - the user may visit the provided URL directly. + The default browser will be opened for the user to complete the login. + Alternatively, the user may visit the provided URL directly. Configuration: @@ -301,19 +376,29 @@ Configuration: Vault role of type "OIDC" to use for authentication. %s= - Optional address to bind the OIDC callback listener to (default: localhost). + Mode of callback: "direct" for direct connection to Vault or "client" + for connection to command line client (default: client). + + %s= + Optional address to bind the OIDC callback listener to in client callback + mode (default: localhost). %s= - Optional localhost port to use for OIDC callback (default: 8250). + Optional localhost port to use for OIDC callback in client callback mode + (default: 8250). %s= - Optional method to to use in OIDC redirect_uri (default: http). + Optional method to use in OIDC redirect_uri (default: the method from + $VAULT_ADDR in direct callback mode, else http) %s= - Optional callback host address to use in OIDC redirect_uri (default: localhost). + Optional callback host address to use in OIDC redirect_uri (default: + the host from $VAULT_ADDR in direct callback mode, else localhost). %s= - Optional port to to use in OIDC redirect_uri (default: the value set for port). + Optional port to use in OIDC redirect_uri (default: the value set for + port in client callback mode, else the port from $VAULT_ADDR with an + added /v1/auth/ where is from the login -path option). %s= Toggle the automatic launching of the default browser to the login URL. (default: false). @@ -321,6 +406,7 @@ Configuration: %s= Abort on any error. (default: false). `, + FieldCallbackMode, FieldListenAddress, FieldPort, FieldCallbackMethod, FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser, FieldAbortOnError, diff --git a/cli_responses.go b/cli_responses.go deleted file mode 100644 index afed8a5c..00000000 --- a/cli_responses.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: MPL-2.0 - -package jwtauth - -import "fmt" - -const successHTML = ` - - - - - - Vault Authentication Succeeded - - - -
-
- -
- -
-
- Signed in via your OIDC provider -
-

- You can now close this window and start using Vault. -

-
-
-
-

Not sure how to get started?

-

- Check out beginner and advanced guides on HashiCorp Vault at the HashiCorp Learn site or read more in the official documentation. -

- - - - - - - Get started with Vault - - - - - - - - View the official Vault documentation - -
-
- - -` - -func errorHTML(summary, detail string) string { - const html = ` - - - - - - - -HashiCorp Vault - - - -
-
- -
- - - -
-
- %s -
-

- %s -

-
-
-
- -

Not sure how to get started?

-

- Check out beginner and advanced guides on HashiCorp Vault at the HashiCorp Learn site or read more in the official documentation. -

- - - - - - - Get started with Vault - - - - - - - - - - View the official Vault documentation - -
-
- - - -` - return fmt.Sprintf(html, summary, detail) -} - -func formpostHTML(path, code, state string) string { - const html = ` - - - - - - Complete sign-in process - - - -
-
- -
- -
-
- Completing the sign-in process... -
-
-
-
-

Not sure how to get started?

-

- Check out beginner and advanced guides on HashiCorp Vault at the HashiCorp Learn site or read more in the official documentation. -

- - - - - - - Get started with Vault - - - - - - - - View the official Vault documentation - -
-
- - - -` - return fmt.Sprintf(html, path, code, state) -} diff --git a/path_oidc.go b/path_oidc.go index a34a7fa9..6f7304d9 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -6,7 +6,6 @@ package jwtauth import ( "context" "encoding/json" - "errors" "fmt" "net" "net/http" @@ -52,6 +51,9 @@ type oidcRequest struct { // clientNonce is used between Vault and the client/application (e.g. CLI) making the request, // and is unrelated to the OIDC nonce above. It is optional. clientNonce string + + // this is for storing the response in direct callback mode + auth *logical.Auth } func pathOIDC(b *jwtAuthBackend) []*framework.Path { @@ -82,6 +84,9 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { Type: framework.TypeString, Query: true, }, + "error_description": { + Type: framework.TypeString, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ @@ -105,6 +110,26 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, }, }, + { + Pattern: `oidc/poll`, + Fields: map[string]*framework.FieldSchema{ + "state": { + Type: framework.TypeString, + }, + "client_nonce": { + Type: framework.TypeString, + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathPoll, + Summary: "Poll endpoint to complete an OIDC login.", + + // state is cached so don't process OIDC logins on perf standbys + ForwardPerformanceStandby: true, + }, + }, + }, { Pattern: `oidc/auth_url`, @@ -125,7 +150,7 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, "client_nonce": { Type: framework.TypeString, - Description: "Optional client-provided nonce that must match during callback, if present.", + Description: "Client-provided nonce that must match during callback, if present. Required only in direct callback mode.", }, }, @@ -167,11 +192,14 @@ func (b *jwtAuthBackend) pathCallbackPost(ctx context.Context, req *logical.Requ } // Store the provided code and/or token into its OIDC request, which must already exist. - oidcReq, err := b.amendOIDCRequest(stateID, code, idToken) - if err != nil { + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil { resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Expired or missing OAuth state.")) resp.Data[logical.HTTPStatusCode] = http.StatusBadRequest } else { + oidcReq.code = code + oidcReq.idToken = idToken + b.setOIDCRequest(stateID, oidcReq) mount := parseMount(oidcReq.RedirectURL()) if mount == "" { resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Invalid redirect path.")) @@ -184,6 +212,19 @@ func (b *jwtAuthBackend) pathCallbackPost(ctx context.Context, req *logical.Requ return resp, nil } +func loginFailedResponse(useHttp bool, msg string) *logical.Response { + if !useHttp { + return logical.ErrorResponse(errLoginFailed + " " + msg) + } + return &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPContentType: "text/html", + logical.HTTPStatusCode: http.StatusBadRequest, + logical.HTTPRawBody: []byte(errorHTML(errLoginFailed, msg)), + }, + } +} + func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { config, err := b.config(ctx, req.Storage) if err != nil { @@ -195,28 +236,45 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, stateID := d.Get("state").(string) - oidcReq := b.verifyOIDCRequest(stateID) - if oidcReq == nil { + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil || oidcReq.auth != nil { return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } - clientNonce := d.Get("client_nonce").(string) - - // If a client_nonce was provided at the start of the auth process as part of the auth_url - // request, require that it is present and matching during the callback phase. - if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { - return logical.ErrorResponse("invalid client_nonce"), nil - } - roleName := oidcReq.rolename role, err := b.role(ctx, req.Storage, roleName) if err != nil { + b.deleteOIDCRequest(stateID) return nil, err } if role == nil { + b.deleteOIDCRequest(stateID) return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil } + useHttp := false + if role.CallbackMode == callbackModeDirect { + useHttp = true + } + if !useHttp { + // state is only accessed once when not using direct callback + b.deleteOIDCRequest(stateID) + } + + errorDescription := d.Get("error_description").(string) + if errorDescription != "" { + return loginFailedResponse(useHttp, errorDescription), nil + } + + clientNonce := d.Get("client_nonce").(string) + + // If a client_nonce was provided at the start of the auth process as part of the auth_url + // request, require that it is present and matching during the callback phase + // unless using the direct callback mode (when we instead check in poll). + if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce && !useHttp { + return logical.ErrorResponse("invalid client_nonce"), nil + } + if len(role.TokenBoundCIDRs) > 0 { if req.Connection == nil { b.Logger().Warn("token bound CIDRs found but no connection information available for validation") @@ -242,7 +300,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, if code == "" { if oidcReq.idToken == "" { - return logical.ErrorResponse(errLoginFailed + " No code or id_token received."), nil + return loginFailedResponse(useHttp, "No code or id_token received."), nil } // Verify the ID token received from the authentication response. @@ -255,7 +313,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, // ID token verification takes place in provider.Exchange. token, err = provider.Exchange(ctx, oidcReq, stateID, code) if err != nil { - return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil + return loginFailedResponse(useHttp, fmt.Sprintf("Error exchanging oidc code: %q.", err.Error())), nil } rawToken = token.IDToken() @@ -287,7 +345,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } if role.BoundSubject != "" && role.BoundSubject != subject { - return nil, errors.New("sub claim does not match bound subject") + return loginFailedResponse(useHttp, "sub claim does not match bound subject"), nil } // Set the token source for the access token if it's available. It will only @@ -321,11 +379,11 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, tokenSource) if err != nil { - return logical.ErrorResponse(err.Error()), nil + return loginFailedResponse(useHttp, err.Error()), nil } if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil { - return logical.ErrorResponse("error validating claims: %s", err.Error()), nil + return loginFailedResponse(useHttp, fmt.Sprintf("error validating claims: %s", err.Error())), nil } tokenMetadata := make(map[string]string) @@ -354,13 +412,49 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, role.PopulateTokenAuth(auth) - resp := &logical.Response{ - Auth: auth, + resp := &logical.Response{} + if useHttp { + oidcReq.auth = auth + b.setOIDCRequest(stateID, oidcReq) + resp.Data = map[string]interface{}{ + logical.HTTPContentType: "text/html", + logical.HTTPStatusCode: http.StatusOK, + logical.HTTPRawBody: []byte(successHTML), + } + } else { + resp.Auth = auth } return resp, nil } +func (b *jwtAuthBackend) pathPoll(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + stateID := d.Get("state").(string) + + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil { + return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil + } + + clientNonce := d.Get("client_nonce").(string) + + if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { + b.deleteOIDCRequest(stateID) + return logical.ErrorResponse("invalid client_nonce"), nil + } + + if oidcReq.auth == nil { + // Return the same response as oauth 2.0 device flow in RFC8628 + return logical.ErrorResponse("authorization_pending"), nil + } + + b.deleteOIDCRequest(stateID) + resp := &logical.Response{ + Auth: oidcReq.auth, + } + return resp, nil +} + // authURL returns a URL used for redirection to receive an authorization code. // This path requires a role name, or that a default_role has been configured. // Because this endpoint is unauthenticated, the response to invalid or non-OIDC @@ -400,8 +494,6 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return logical.ErrorResponse("missing redirect_uri"), nil } - clientNonce := d.Get("client_nonce").(string) - role, err := b.role(ctx, req.Storage, roleName) if err != nil { return nil, err @@ -410,9 +502,14 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return logical.ErrorResponse("role %q could not be found", roleName), nil } - // If namespace will be passed around in state, and it has been provided as + clientNonce := d.Get("client_nonce").(string) + if clientNonce == "" && role.CallbackMode == callbackModeDirect { + return logical.ErrorResponse("missing client_nonce"), nil + } + + // If namespace will be passed around in oidcReq, and it has been provided as // a redirectURI query parameter, remove it from redirectURI, and append it - // to the state (later in this function) + // to the oidcReq (later in this function) namespace := "" if config.NamespaceInState { inputURI, err := url.Parse(redirectURI) @@ -460,13 +557,17 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return resp, nil } - // embed namespace in state in the auth_url + // embed namespace in oidcReq in the auth_url if config.NamespaceInState && len(namespace) > 0 { stateWithNamespace := fmt.Sprintf("%s,ns=%s", oidcReq.State(), namespace) urlStr = strings.Replace(urlStr, oidcReq.State(), url.QueryEscape(stateWithNamespace), 1) } resp.Data["auth_url"] = urlStr + if role.CallbackMode == callbackModeDirect { + resp.Data["state"] = oidcReq.State() + resp.Data["poll_interval"] = "5" + } return resp, nil } @@ -509,32 +610,14 @@ func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rol return oidcReq, nil } -func (b *jwtAuthBackend) amendOIDCRequest(stateID, code, idToken string) (*oidcRequest, error) { - requestRaw, ok := b.oidcRequests.Get(stateID) - if !ok { - return nil, errors.New("OIDC state not found") - } - - oidcReq := requestRaw.(*oidcRequest) - oidcReq.code = code - oidcReq.idToken = idToken - +func (b *jwtAuthBackend) setOIDCRequest(stateID string, oidcReq *oidcRequest) { b.oidcRequests.SetDefault(stateID, oidcReq) - - return oidcReq, nil } -// verifyOIDCRequest tests whether the provided state ID is valid and returns the -// associated oidcRequest if so. A nil oidcRequest is returned if the ID is not found -// or expired. The oidcRequest should only ever be retrieved once and is deleted as -// part of this request. -func (b *jwtAuthBackend) verifyOIDCRequest(stateID string) *oidcRequest { - defer b.oidcRequests.Delete(stateID) - +func (b *jwtAuthBackend) getOIDCRequest(stateID string) *oidcRequest { if requestRaw, ok := b.oidcRequests.Get(stateID); ok { return requestRaw.(*oidcRequest) } - return nil } @@ -549,6 +632,10 @@ func isLocalAddr(hostname string) bool { return hostname == "localhost" } +func (b *jwtAuthBackend) deleteOIDCRequest(stateID string) { + b.oidcRequests.Delete(stateID) +} + // validRedirect checks whether uri is in allowed using special handling for loopback uris. // Ref: https://tools.ietf.org/html/rfc8252#section-7.3 func validRedirect(uri string, allowed []string) bool { diff --git a/path_oidc_test.go b/path_oidc_test.go index 2b66a1ac..1354c140 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -523,7 +523,7 @@ func TestOIDC_AuthURL_max_age(t *testing.T) { // pointer syntax for the user_claim of roles. For claims used // in assertions, see the sampleClaims function. func TestOIDC_UserClaim_JSON_Pointer(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() type args struct { @@ -769,14 +769,27 @@ func TestOIDC_ResponseTypeIDToken(t *testing.T) { func TestOIDC_Callback(t *testing.T) { t.Run("successful login", func(t *testing.T) { // run test with and without bound_cidrs configured - for _, useBoundCIDRs := range []bool{false, true} { - b, storage, s := getBackendAndServer(t, useBoundCIDRs) + // and with and without direct callback mode + for i := 1; i <= 3; i++ { + var useBoundCIDRs bool + var callbackMode string + + if i == 2 { + useBoundCIDRs = true + } else if i == 3 { + callbackMode = "direct" + } + + b, storage, s := getBackendAndServer(t, useBoundCIDRs, callbackMode) defer s.server.Close() + clientNonce := "456" + // get auth_url data := map[string]interface{}{ "role": "test", "redirect_uri": "https://example.com", + "client_nonce": clientNonce, } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -811,8 +824,9 @@ func TestOIDC_Callback(t *testing.T) { Path: "oidc/callback", Storage: storage, Data: map[string]interface{}{ - "state": state, - "code": "abc", + "state": state, + "code": "abc", + "client_nonce": clientNonce, }, Connection: &logical.Connection{ RemoteAddr: "127.0.0.42", @@ -824,6 +838,22 @@ func TestOIDC_Callback(t *testing.T) { t.Fatal(err) } + if callbackMode == "direct" { + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "oidc/poll", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "client_nonce": clientNonce, + }, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + } + expected := &logical.Auth{ LeaseOptions: logical.LeaseOptions{ Renewable: true, @@ -870,7 +900,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed login - bad nonce", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -924,7 +954,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed login - bound claim mismatch", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -980,7 +1010,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("missing state", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() req := &logical.Request{ @@ -999,7 +1029,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("unknown state", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() req := &logical.Request{ @@ -1021,7 +1051,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("valid state, missing code", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1063,7 +1093,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed code exchange", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1113,7 +1143,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed code exchange (PKCE)", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1165,7 +1195,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("no response from provider", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") // get auth_url data := map[string]interface{}{ @@ -1211,7 +1241,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("test bad address", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, true) + b, storage, s := getBackendAndServer(t, true, "") defer s.server.Close() s.code = "abc" @@ -1256,7 +1286,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("test invalid client_id", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() s.code = "abc" @@ -1312,7 +1342,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("client_nonce", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // General behavior is that if a client_nonce is provided during the authURL phase @@ -1583,7 +1613,7 @@ func TestOIDC_ValidRedirect(t *testing.T) { } } -func getBackendAndServer(t *testing.T, boundCIDRs bool) (logical.Backend, logical.Storage, *oidcProvider) { +func getBackendAndServer(t *testing.T, boundCIDRs bool, callbackMode string) (logical.Backend, logical.Storage, *oidcProvider) { b, storage := getBackend(t) s := newOIDCProvider(t) s.clientID = "abc" @@ -1642,6 +1672,10 @@ func getBackendAndServer(t *testing.T, boundCIDRs bool) (logical.Backend, logica data["bound_cidrs"] = "127.0.0.42" } + if callbackMode != "" { + data["callback_mode"] = callbackMode + } + req = &logical.Request{ Operation: logical.CreateOperation, Path: "role/test", diff --git a/path_role.go b/path_role.go index 188f74d6..6e1b33e6 100644 --- a/path_role.go +++ b/path_role.go @@ -24,6 +24,8 @@ const ( claimDefaultLeeway = 150 boundClaimsTypeString = "string" boundClaimsTypeGlob = "glob" + callbackModeDirect = "direct" + callbackModeClient = "client" ) func pathRoleList(b *jwtAuthBackend) *framework.Path { @@ -154,6 +156,11 @@ for referencing claims.`, Type: framework.TypeCommaStringSlice, Description: `Comma-separated list of allowed values for redirect_uri`, }, + "callback_mode": { + Type: framework.TypeString, + Description: `OIDC callback mode from Authorization Server: allowed values are 'direct' to Vault or 'client', default 'client'`, + Default: callbackModeClient, + }, "verbose_oidc_logging": { Type: framework.TypeBool, Description: `Log received OIDC tokens and claims when debug-level logging is active. @@ -222,6 +229,7 @@ type jwtRole struct { GroupsClaim string `json:"groups_claim"` OIDCScopes []string `json:"oidc_scopes"` AllowedRedirectURIs []string `json:"allowed_redirect_uris"` + CallbackMode string `json:"callback_mode"` VerboseOIDCLogging bool `json:"verbose_oidc_logging"` MaxAge time.Duration `json:"max_age"` UserClaimJSONPointer bool `json:"user_claim_json_pointer"` @@ -330,6 +338,7 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, "user_claim_json_pointer": role.UserClaimJSONPointer, "groups_claim": role.GroupsClaim, "allowed_redirect_uris": role.AllowedRedirectURIs, + "callback_mode": role.CallbackMode, "oidc_scopes": role.OIDCScopes, "verbose_oidc_logging": role.VerboseOIDCLogging, "max_age": int64(role.MaxAge.Seconds()), @@ -356,6 +365,20 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, d["num_uses"] = role.NumUses } + if role.CallbackMode == "" { + // Must have been after an upgrade. Store the default value. + role.CallbackMode = "client" + d["callback_mode"] = role.CallbackMode + + entry, err := logical.StorageEntryJSON(rolePrefix+roleName, role) + if err != nil { + return nil, err + } + if err = req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + } + return &logical.Response{ Data: d, }, nil @@ -541,6 +564,14 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role.AllowedRedirectURIs = allowedRedirectURIs.([]string) } + callbackMode := data.Get("callback_mode").(string) + switch callbackMode { + case callbackModeDirect, callbackModeClient: + role.CallbackMode = callbackMode + default: + return logical.ErrorResponse("invalid 'callback_mode': %s", callbackMode), nil + } + if role.RoleType == "oidc" && len(role.AllowedRedirectURIs) == 0 { return logical.ErrorResponse( "'allowed_redirect_uris' must be set if 'role_type' is 'oidc' or unspecified."), nil diff --git a/path_role_test.go b/path_role_test.go index 5628433f..96d56f60 100644 --- a/path_role_test.go +++ b/path_role_test.go @@ -91,6 +91,7 @@ func TestPath_Create(t *testing.T) { NumUses: 12, BoundCIDRs: []*sockaddr.SockAddrMarshaler{{SockAddr: expectedSockAddr}}, AllowedRedirectURIs: []string(nil), + CallbackMode: "client", MaxAge: 60 * time.Second, } @@ -564,6 +565,7 @@ func TestPath_OIDCCreate(t *testing.T) { "bar": "baz", }, AllowedRedirectURIs: []string{"https://example.com", "http://localhost:8250"}, + CallbackMode: "client", ClaimMappings: map[string]string{ "foo": "a", "bar": "b", @@ -770,6 +772,7 @@ func TestPath_Read(t *testing.T) { "bound_subject": "testsub", "bound_audiences": []string{"vault"}, "allowed_redirect_uris": []string{"http://127.0.0.1"}, + "callback_mode": "client", "oidc_scopes": []string{"email", "profile"}, "user_claim": "user", "user_claim_json_pointer": false,