-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Oauth implementation. Added m2m authenticator. Added basic u2m authenticator. Implementation is broken out into public functions that clients can use to implement their own authenticator. Allow specifying auth type and parameters in dsn.
- Loading branch information
Showing
17 changed files
with
1,388 additions
and
275 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,40 @@ | ||
package auth | ||
|
||
import "net/http" | ||
import ( | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
type Authenticator interface { | ||
Authenticate(*http.Request) error | ||
} | ||
|
||
type AuthType int | ||
|
||
const ( | ||
AuthTypeUnknown AuthType = iota | ||
AuthTypePat | ||
AuthTypeOauthU2M | ||
AuthTypeOauthM2M | ||
) | ||
|
||
var authTypeNames []string = []string{"Unknown", "Pat", "OauthU2M", "OauthM2M"} | ||
|
||
func (at AuthType) String() string { | ||
if at >= 0 && int(at) < len(authTypeNames) { | ||
return authTypeNames[at] | ||
} | ||
|
||
return authTypeNames[0] | ||
} | ||
|
||
func ParseAuthType(typeString string) AuthType { | ||
typeString = strings.ToLower(typeString) | ||
for i, n := range authTypeNames { | ||
if strings.ToLower(n) == typeString { | ||
return AuthType(i) | ||
} | ||
} | ||
|
||
return AuthTypeUnknown | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package m2m | ||
|
||
// clientid e92aa085-4875-42fe-ad75-ba38fb3c9706 | ||
// secretid vUdzecmn4aUi2jRDamaBOy3qThu9LSgeV_BW4UnQ | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"sync" | ||
|
||
"github.com/databricks/databricks-sql-go/auth" | ||
"github.com/databricks/databricks-sql-go/auth/oauth" | ||
"github.com/rs/zerolog/log" | ||
"golang.org/x/oauth2" | ||
"golang.org/x/oauth2/clientcredentials" | ||
) | ||
|
||
func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator { | ||
scopes := oauth.GetScopes(hostName, []string{}) | ||
return &authClient{ | ||
clientID: clientID, | ||
clientSecret: clientSecret, | ||
hostName: hostName, | ||
scopes: scopes, | ||
} | ||
} | ||
|
||
type authClient struct { | ||
clientID string | ||
clientSecret string | ||
hostName string | ||
scopes []string | ||
tokenSource oauth2.TokenSource | ||
mx sync.Mutex | ||
} | ||
|
||
// Auth will start the OAuth Authorization Flow to authenticate the cli client | ||
// using the users credentials in the browser. Compatible with SSO. | ||
func (c *authClient) Authenticate(r *http.Request) error { | ||
c.mx.Lock() | ||
defer c.mx.Unlock() | ||
if c.tokenSource != nil { | ||
token, err := c.tokenSource.Token() | ||
if err != nil { | ||
return err | ||
} | ||
token.SetAuthHeader(r) | ||
return nil | ||
} | ||
|
||
config, err := GetConfig(context.Background(), c.hostName, c.clientID, c.clientSecret, c.scopes) | ||
if err != nil { | ||
return fmt.Errorf("unable to generate clientCredentials.Config: %w", err) | ||
} | ||
|
||
c.tokenSource = GetTokenSource(config) | ||
token, err := c.tokenSource.Token() | ||
log.Info().Msgf("token fetched successfully") | ||
if err != nil { | ||
log.Err(err).Msg("failed to get token") | ||
|
||
return err | ||
} | ||
token.SetAuthHeader(r) | ||
|
||
return nil | ||
|
||
} | ||
|
||
func GetTokenSource(config clientcredentials.Config) oauth2.TokenSource { | ||
tokenSource := config.TokenSource(context.Background()) | ||
return tokenSource | ||
} | ||
|
||
func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, scopes []string) (clientcredentials.Config, error) { | ||
// Get the endpoint based on the host name | ||
endpoint, err := oauth.GetEndpoint(ctx, issuerURL) | ||
if err != nil { | ||
return clientcredentials.Config{}, fmt.Errorf("could not lookup provider details: %w", err) | ||
} | ||
|
||
config := clientcredentials.Config{ | ||
ClientID: clientID, | ||
ClientSecret: clientSecret, | ||
TokenURL: endpoint.TokenURL, | ||
Scopes: scopes, | ||
} | ||
|
||
return config, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package oauth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/coreos/go-oidc/v3/oidc" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
const ( | ||
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68" | ||
) | ||
|
||
func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) { | ||
if ctx == nil { | ||
ctx = context.Background() | ||
} | ||
|
||
cloud := InferCloudFromHost(hostName) | ||
|
||
if cloud == Unknown { | ||
return oauth2.Endpoint{}, errors.New("unsupported cloud type") | ||
} | ||
|
||
if cloud == Azure { | ||
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName) | ||
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName) | ||
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil | ||
} | ||
|
||
issuerURL := fmt.Sprintf("https://%s/oidc", hostName) | ||
ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL) | ||
provider, err := oidc.NewProvider(ctx, issuerURL) | ||
if err != nil { | ||
return oauth2.Endpoint{}, err | ||
} | ||
|
||
endpoint := provider.Endpoint() | ||
|
||
return endpoint, err | ||
} | ||
|
||
func GetScopes(hostName string, scopes []string) []string { | ||
for _, s := range []string{oidc.ScopeOfflineAccess} { | ||
if !hasScope(scopes, s) { | ||
scopes = append(scopes, s) | ||
} | ||
} | ||
|
||
cloudType := InferCloudFromHost(hostName) | ||
if cloudType == Azure { | ||
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId) | ||
if !hasScope(scopes, userImpersonationScope) { | ||
scopes = append(scopes, userImpersonationScope) | ||
} | ||
} else { | ||
if !hasScope(scopes, "sql") { | ||
scopes = append(scopes, "sql") | ||
} | ||
} | ||
|
||
return scopes | ||
} | ||
|
||
func hasScope(scopes []string, scope string) bool { | ||
for _, s := range scopes { | ||
if s == scope { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
var databricksAWSDomains []string = []string{ | ||
".cloud.databricks.com", | ||
".dev.databricks.com", | ||
} | ||
|
||
var databricksAzureDomains []string = []string{ | ||
".azuredatabricks.net", | ||
".databricks.azure.cn", | ||
".databricks.azure.us", | ||
} | ||
|
||
type CloudType int | ||
|
||
const ( | ||
AWS = iota | ||
Azure | ||
Unknown | ||
) | ||
|
||
func (cl CloudType) String() string { | ||
switch cl { | ||
case AWS: | ||
return "AWS" | ||
case Azure: | ||
return "Azure" | ||
} | ||
|
||
return "Unknown" | ||
} | ||
|
||
func InferCloudFromHost(hostname string) CloudType { | ||
|
||
for _, d := range databricksAzureDomains { | ||
if strings.Contains(hostname, d) { | ||
return Azure | ||
} | ||
} | ||
|
||
for _, d := range databricksAWSDomains { | ||
if strings.Contains(hostname, d) { | ||
return AWS | ||
} | ||
} | ||
|
||
return Unknown | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package pkce | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/sha256" | ||
"encoding/base64" | ||
"encoding/hex" | ||
"fmt" | ||
"io" | ||
|
||
"golang.org/x/oauth2" | ||
) | ||
|
||
// Generate generates a new random PKCE code. | ||
func Generate() (Code, error) { return generate(rand.Reader) } | ||
|
||
func generate(rand io.Reader) (Code, error) { | ||
// From https://tools.ietf.org/html/rfc7636#section-4.1: | ||
// code_verifier = high-entropy cryptographic random STRING using the | ||
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" | ||
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters | ||
// and a maximum length of 128 characters. | ||
var buf [32]byte | ||
if _, err := io.ReadFull(rand, buf[:]); err != nil { | ||
return "", fmt.Errorf("could not generate PKCE code: %w", err) | ||
} | ||
return Code(hex.EncodeToString(buf[:])), nil | ||
} | ||
|
||
// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE). | ||
type Code string | ||
|
||
// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge. | ||
func (p *Code) Challenge() oauth2.AuthCodeOption { | ||
b := sha256.Sum256([]byte(*p)) | ||
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:])) | ||
} | ||
|
||
// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method. | ||
func (p *Code) Method() oauth2.AuthCodeOption { | ||
return oauth2.SetAuthURLParam("code_challenge_method", "S256") | ||
} | ||
|
||
// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier. | ||
func (p *Code) Verifier() oauth2.AuthCodeOption { | ||
return oauth2.SetAuthURLParam("code_verifier", string(*p)) | ||
} |
Oops, something went wrong.