Skip to content

Commit

Permalink
Add oidc callback mode that is direct to server
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com>
  • Loading branch information
DrDaveD committed Jun 7, 2024
1 parent 5ac7393 commit 80e881b
Show file tree
Hide file tree
Showing 10 changed files with 442 additions and 99 deletions.
1 change: 1 addition & 0 deletions builtin/credential/jwt/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func backend() *jwtAuthBackend {
"login",
"oidc/auth_url",
"oidc/callback",
"oidc/poll",

// Uncomment to mount simple UI handler for local development
// "ui",
Expand Down
135 changes: 113 additions & 22 deletions builtin/credential/jwt/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path"
Expand All @@ -27,9 +28,11 @@ const (
defaultPort = "8250"
defaultCallbackHost = "localhost"
defaultCallbackMethod = "http"
defaultCallbackMode = "client"

FieldCallbackHost = "callbackhost"
FieldCallbackMethod = "callbackmethod"
FieldCallbackMode = "callbackmode"
FieldListenAddress = "listenaddress"
FieldPort = "port"
FieldCallbackPort = "callbackport"
Expand Down Expand Up @@ -69,19 +72,44 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo
port = defaultPort
}

var serverURL *url.URL
callbackMode, ok := m[FieldCallbackMode]
if !ok || callbackMode == "" {
callbackMode = defaultCallbackMode
} else if callbackMode == "direct" {
serverAddr := api.ReadBaoVariable("BAO_ADDR")
if serverAddr != "" {
serverURL, _ = url.Parse(serverAddr)
}
}

callbackHost, ok := m[FieldCallbackHost]
if !ok {
callbackHost = defaultCallbackHost
if serverURL != nil {
callbackHost = serverURL.Hostname()
} else {
// Note that since defaultCallbackHost is localhost,
// this only works if the cli is run on the server
callbackHost = defaultCallbackHost
}
}

callbackMethod, ok := m[FieldCallbackMethod]
if !ok {
callbackMethod = defaultCallbackMethod
if serverURL != nil {
callbackMethod = serverURL.Scheme
} else {
callbackMethod = defaultCallbackMethod
}
}

callbackPort, ok := m[FieldCallbackPort]
if !ok {
callbackPort = port
if serverURL != nil {
callbackPort = serverURL.Port() + "/v1/auth/" + mount
} else {
callbackPort = port
}
}

parseBool := func(f string, d bool) (bool, error) {
Expand Down Expand Up @@ -115,20 +143,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo

role := m["role"]

authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
if err != nil {
return nil, err
}

// Set up callback handler
doneCh := make(chan loginResp)
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
var pollInterval string
var interval int
var state string
var listener net.Listener

if secret != nil {
pollInterval, _ = secret.Data["poll_interval"].(string)
state, _ = secret.Data["state"].(string)
}
if callbackMode == "direct" {
if state == "" {
return nil, errors.New("no state returned in direct callback mode")
}
if pollInterval == "" {
return nil, errors.New("no poll_interval returned in direct callback mode")
}
interval, err = strconv.Atoi(pollInterval)
if err != nil {
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
}
} else {
if state != "" {
return nil, errors.New("state returned in client callback mode, try direct")
}
if pollInterval != "" {
return nil, errors.New("poll_interval returned in client callback mode")
}
// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err = net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
}
defer listener.Close()
}
defer listener.Close()

// Open the default browser to the callback URL.
if !skipBrowserLaunch {
Expand All @@ -144,6 +201,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo
}
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")

if callbackMode == "direct" {
data := map[string]interface{}{
"state": state,
"client_nonce": clientNonce,
}
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
for {
time.Sleep(time.Duration(interval) * time.Second)

secret, err := c.Logical().Write(pollUrl, data)
if err == nil {
return secret, nil
}
if strings.HasSuffix(err.Error(), "slow_down") {
interval *= 2
} else if !strings.HasSuffix(err.Error(), "authorization_pending") {
return nil, err
}
// authorization is pending, try again
}
}

// Start local server
go func() {
err := http.Serve(listener, nil)
Expand Down Expand Up @@ -210,12 +289,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
}
}

func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
var authURL string

clientNonce, err := base62.Random(20)
if err != nil {
return "", "", err
return "", "", nil, err
}

redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
Expand All @@ -227,18 +306,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho

secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
if err != nil {
return "", "", err
return "", "", nil, err
}

if secret != nil {
authURL = secret.Data["auth_url"].(string)
}

if authURL == "" {
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check OpenBao logs for more information.", role, redirectURI)
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check OpenBao logs for more information.", role, redirectURI)
}

return authURL, clientNonce, nil
return authURL, clientNonce, secret, nil
}

// parseError converts error from the API into summary and detailed portions.
Expand Down Expand Up @@ -292,35 +371,47 @@ Usage: bao login -method=oidc [CONFIG K=V...]
https://accounts.google.com/o/oauth2/v2/...
The default browser will be opened for the user to complete the login. Alternatively,
the user may visit the provided URL directly.
The default browser will be opened for the user to complete the login.
Alternatively, the user may visit the provided URL directly.
Configuration:
role=<string>
OpenBao role of type "OIDC" to use for authentication.
%s=<string>
Optional address to bind the OIDC callback listener to (default: localhost).
Mode of callback: "direct" for direct connection to the server or "client"
for connection to the command line client (default: client).
%s=<string>
Optional address to bind the OIDC callback listener to in client callback
mode (default: localhost).
%s=<string>
Optional localhost port to use for OIDC callback (default: 8250).
Optional localhost port to use for OIDC callback in client callback mode
(default: 8250).
%s=<string>
Optional method to to use in OIDC redirect_uri (default: http).
Optional method to use in OIDC redirect_uri (default: the method from
$BAO_ADDR or $VAULT_ADDR in direct callback mode, else http)
%s=<string>
Optional callback host address to use in OIDC redirect_uri (default: localhost).
Optional callback host address to use in OIDC redirect_uri (default:
the host from $BAO_ADDR or $VAULT_ADDR in direct callback mode, else
localhost).
%s=<string>
Optional port to to use in OIDC redirect_uri (default: the value set for port).
Optional port to use in OIDC redirect_uri (default: the value set for
port in client callback mode, else the port from $BAO_ADDR or $VAULT_ADDR
with an added /v1/auth/<path> where <path> is from the login -path option).
%s=<bool>
Toggle the automatic launching of the default browser to the login URL. (default: false).
%s=<bool>
Abort on any error. (default: false).
`,
FieldCallbackMode,
FieldListenAddress, FieldPort, FieldCallbackMethod,
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
FieldAbortOnError,
Expand Down
File renamed without changes.
Loading

0 comments on commit 80e881b

Please sign in to comment.