-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial oauth implementation #122
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
b63e445
Initial oauth implementation
rcypher-databricks 0127175
Moved default authenticator into internal package.
rcypher-databricks 5579ce8
Refactoring out of internal and dsn parsing
rcypher-databricks ac38a7e
Updated doc.go
rcypher-databricks 57a3105
Removed unused constants
rcypher-databricks 74df4bf
Removed unused function
rcypher-databricks File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,40 @@ | ||
package auth | ||
|
||
import "net/http" | ||
import ( | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
type Authenticator interface { | ||
Authenticate(*http.Request) error | ||
} | ||
|
||
type AuthType int | ||
|
||
const ( | ||
AuthTypeUnknown AuthType = iota | ||
AuthTypePat | ||
AuthTypeOauthU2M | ||
AuthTypeOauthM2M | ||
) | ||
|
||
var authTypeNames []string = []string{"Unknown", "Pat", "OauthU2M", "OauthM2M"} | ||
|
||
func (at AuthType) String() string { | ||
if at >= 0 && int(at) < len(authTypeNames) { | ||
return authTypeNames[at] | ||
} | ||
|
||
return authTypeNames[0] | ||
} | ||
|
||
func ParseAuthType(typeString string) AuthType { | ||
typeString = strings.ToLower(typeString) | ||
for i, n := range authTypeNames { | ||
if strings.ToLower(n) == typeString { | ||
return AuthType(i) | ||
} | ||
} | ||
|
||
return AuthTypeUnknown | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package m2m | ||
|
||
// clientid e92aa085-4875-42fe-ad75-ba38fb3c9706 | ||
// secretid vUdzecmn4aUi2jRDamaBOy3qThu9LSgeV_BW4UnQ | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"sync" | ||
|
||
"github.com/databricks/databricks-sql-go/auth" | ||
"github.com/databricks/databricks-sql-go/auth/oauth" | ||
"github.com/rs/zerolog/log" | ||
"golang.org/x/oauth2" | ||
"golang.org/x/oauth2/clientcredentials" | ||
) | ||
|
||
func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator { | ||
scopes := oauth.GetScopes(hostName, []string{}) | ||
return &authClient{ | ||
clientID: clientID, | ||
clientSecret: clientSecret, | ||
hostName: hostName, | ||
scopes: scopes, | ||
} | ||
} | ||
|
||
type authClient struct { | ||
clientID string | ||
clientSecret string | ||
hostName string | ||
scopes []string | ||
tokenSource oauth2.TokenSource | ||
mx sync.Mutex | ||
} | ||
|
||
// Auth will start the OAuth Authorization Flow to authenticate the cli client | ||
// using the users credentials in the browser. Compatible with SSO. | ||
func (c *authClient) Authenticate(r *http.Request) error { | ||
c.mx.Lock() | ||
defer c.mx.Unlock() | ||
if c.tokenSource != nil { | ||
token, err := c.tokenSource.Token() | ||
if err != nil { | ||
return err | ||
} | ||
token.SetAuthHeader(r) | ||
return nil | ||
} | ||
|
||
config, err := GetConfig(context.Background(), c.hostName, c.clientID, c.clientSecret, c.scopes) | ||
if err != nil { | ||
return fmt.Errorf("unable to generate clientCredentials.Config: %w", err) | ||
} | ||
|
||
c.tokenSource = GetTokenSource(config) | ||
token, err := c.tokenSource.Token() | ||
log.Info().Msgf("token fetched successfully") | ||
if err != nil { | ||
log.Err(err).Msg("failed to get token") | ||
|
||
return err | ||
} | ||
token.SetAuthHeader(r) | ||
|
||
return nil | ||
|
||
} | ||
|
||
func GetTokenSource(config clientcredentials.Config) oauth2.TokenSource { | ||
tokenSource := config.TokenSource(context.Background()) | ||
return tokenSource | ||
} | ||
|
||
func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, scopes []string) (clientcredentials.Config, error) { | ||
// Get the endpoint based on the host name | ||
endpoint, err := oauth.GetEndpoint(ctx, issuerURL) | ||
if err != nil { | ||
return clientcredentials.Config{}, fmt.Errorf("could not lookup provider details: %w", err) | ||
} | ||
|
||
config := clientcredentials.Config{ | ||
ClientID: clientID, | ||
ClientSecret: clientSecret, | ||
TokenURL: endpoint.TokenURL, | ||
Scopes: scopes, | ||
} | ||
|
||
return config, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package oauth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/coreos/go-oidc/v3/oidc" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
const ( | ||
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68" | ||
) | ||
|
||
func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) { | ||
if ctx == nil { | ||
ctx = context.Background() | ||
} | ||
|
||
cloud := InferCloudFromHost(hostName) | ||
|
||
if cloud == Unknown { | ||
return oauth2.Endpoint{}, errors.New("unsupported cloud type") | ||
} | ||
|
||
if cloud == Azure { | ||
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName) | ||
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName) | ||
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil | ||
} | ||
|
||
issuerURL := fmt.Sprintf("https://%s/oidc", hostName) | ||
ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL) | ||
provider, err := oidc.NewProvider(ctx, issuerURL) | ||
if err != nil { | ||
return oauth2.Endpoint{}, err | ||
} | ||
|
||
endpoint := provider.Endpoint() | ||
|
||
return endpoint, err | ||
} | ||
|
||
func GetScopes(hostName string, scopes []string) []string { | ||
for _, s := range []string{oidc.ScopeOfflineAccess} { | ||
if !hasScope(scopes, s) { | ||
scopes = append(scopes, s) | ||
} | ||
} | ||
|
||
cloudType := InferCloudFromHost(hostName) | ||
if cloudType == Azure { | ||
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId) | ||
if !hasScope(scopes, userImpersonationScope) { | ||
scopes = append(scopes, userImpersonationScope) | ||
} | ||
} else { | ||
if !hasScope(scopes, "sql") { | ||
scopes = append(scopes, "sql") | ||
} | ||
} | ||
|
||
return scopes | ||
} | ||
|
||
func hasScope(scopes []string, scope string) bool { | ||
for _, s := range scopes { | ||
if s == scope { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
var databricksAWSDomains []string = []string{ | ||
".cloud.databricks.com", | ||
".dev.databricks.com", | ||
} | ||
|
||
var databricksAzureDomains []string = []string{ | ||
".azuredatabricks.net", | ||
".databricks.azure.cn", | ||
".databricks.azure.us", | ||
} | ||
|
||
type CloudType int | ||
|
||
const ( | ||
AWS = iota | ||
Azure | ||
Unknown | ||
) | ||
|
||
func (cl CloudType) String() string { | ||
switch cl { | ||
case AWS: | ||
return "AWS" | ||
case Azure: | ||
return "Azure" | ||
} | ||
|
||
return "Unknown" | ||
} | ||
|
||
func InferCloudFromHost(hostname string) CloudType { | ||
|
||
for _, d := range databricksAzureDomains { | ||
if strings.Contains(hostname, d) { | ||
return Azure | ||
} | ||
} | ||
|
||
for _, d := range databricksAWSDomains { | ||
if strings.Contains(hostname, d) { | ||
return AWS | ||
} | ||
} | ||
|
||
return Unknown | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package pkce | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/sha256" | ||
"encoding/base64" | ||
"encoding/hex" | ||
"fmt" | ||
"io" | ||
|
||
"golang.org/x/oauth2" | ||
) | ||
|
||
// Generate generates a new random PKCE code. | ||
func Generate() (Code, error) { return generate(rand.Reader) } | ||
|
||
func generate(rand io.Reader) (Code, error) { | ||
// From https://tools.ietf.org/html/rfc7636#section-4.1: | ||
// code_verifier = high-entropy cryptographic random STRING using the | ||
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" | ||
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters | ||
// and a maximum length of 128 characters. | ||
var buf [32]byte | ||
if _, err := io.ReadFull(rand, buf[:]); err != nil { | ||
return "", fmt.Errorf("could not generate PKCE code: %w", err) | ||
} | ||
return Code(hex.EncodeToString(buf[:])), nil | ||
} | ||
|
||
// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE). | ||
type Code string | ||
|
||
// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge. | ||
func (p *Code) Challenge() oauth2.AuthCodeOption { | ||
b := sha256.Sum256([]byte(*p)) | ||
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:])) | ||
} | ||
|
||
// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method. | ||
func (p *Code) Method() oauth2.AuthCodeOption { | ||
return oauth2.SetAuthURLParam("code_challenge_method", "S256") | ||
} | ||
|
||
// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier. | ||
func (p *Code) Verifier() oauth2.AuthCodeOption { | ||
return oauth2.SetAuthURLParam("code_verifier", string(*p)) | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we use sql or all-apis? Gotta check with Bryan Mcquade