Skip to content

Commit

Permalink
Initial oauth implementation (#122)
Browse files Browse the repository at this point in the history
Oauth implementation.
Added m2m authenticator.
Added basic u2m authenticator. Implementation is broken out into public
functions that clients can use to implement their own authenticator.
Allow specifying auth type and parameters in dsn.
  • Loading branch information
rcypher-databricks authored Sep 5, 2023
2 parents 7a177c9 + 74df4bf commit 34599c4
Show file tree
Hide file tree
Showing 17 changed files with 1,388 additions and 275 deletions.
35 changes: 34 additions & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
package auth

import "net/http"
import (
"net/http"
"strings"
)

type Authenticator interface {
Authenticate(*http.Request) error
}

type AuthType int

const (
AuthTypeUnknown AuthType = iota
AuthTypePat
AuthTypeOauthU2M
AuthTypeOauthM2M
)

var authTypeNames []string = []string{"Unknown", "Pat", "OauthU2M", "OauthM2M"}

func (at AuthType) String() string {
if at >= 0 && int(at) < len(authTypeNames) {
return authTypeNames[at]
}

return authTypeNames[0]
}

func ParseAuthType(typeString string) AuthType {
typeString = strings.ToLower(typeString)
for i, n := range authTypeNames {
if strings.ToLower(n) == typeString {
return AuthType(i)
}
}

return AuthTypeUnknown
}
91 changes: 91 additions & 0 deletions auth/oauth/m2m/m2m.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package m2m

// clientid e92aa085-4875-42fe-ad75-ba38fb3c9706
// secretid vUdzecmn4aUi2jRDamaBOy3qThu9LSgeV_BW4UnQ

import (
"context"
"fmt"
"net/http"
"sync"

"github.com/databricks/databricks-sql-go/auth"
"github.com/databricks/databricks-sql-go/auth/oauth"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator {
scopes := oauth.GetScopes(hostName, []string{})
return &authClient{
clientID: clientID,
clientSecret: clientSecret,
hostName: hostName,
scopes: scopes,
}
}

type authClient struct {
clientID string
clientSecret string
hostName string
scopes []string
tokenSource oauth2.TokenSource
mx sync.Mutex
}

// Auth will start the OAuth Authorization Flow to authenticate the cli client
// using the users credentials in the browser. Compatible with SSO.
func (c *authClient) Authenticate(r *http.Request) error {
c.mx.Lock()
defer c.mx.Unlock()
if c.tokenSource != nil {
token, err := c.tokenSource.Token()
if err != nil {
return err
}
token.SetAuthHeader(r)
return nil
}

config, err := GetConfig(context.Background(), c.hostName, c.clientID, c.clientSecret, c.scopes)
if err != nil {
return fmt.Errorf("unable to generate clientCredentials.Config: %w", err)
}

c.tokenSource = GetTokenSource(config)
token, err := c.tokenSource.Token()
log.Info().Msgf("token fetched successfully")
if err != nil {
log.Err(err).Msg("failed to get token")

return err
}
token.SetAuthHeader(r)

return nil

}

func GetTokenSource(config clientcredentials.Config) oauth2.TokenSource {
tokenSource := config.TokenSource(context.Background())
return tokenSource
}

func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, scopes []string) (clientcredentials.Config, error) {
// Get the endpoint based on the host name
endpoint, err := oauth.GetEndpoint(ctx, issuerURL)
if err != nil {
return clientcredentials.Config{}, fmt.Errorf("could not lookup provider details: %w", err)
}

config := clientcredentials.Config{
ClientID: clientID,
ClientSecret: clientSecret,
TokenURL: endpoint.TokenURL,
Scopes: scopes,
}

return config, nil
}
122 changes: 122 additions & 0 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package oauth

import (
"context"
"errors"
"fmt"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)

const (
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68"
)

func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) {
if ctx == nil {
ctx = context.Background()
}

cloud := InferCloudFromHost(hostName)

if cloud == Unknown {
return oauth2.Endpoint{}, errors.New("unsupported cloud type")
}

if cloud == Azure {
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName)
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName)
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil
}

issuerURL := fmt.Sprintf("https://%s/oidc", hostName)
ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL)
provider, err := oidc.NewProvider(ctx, issuerURL)
if err != nil {
return oauth2.Endpoint{}, err
}

endpoint := provider.Endpoint()

return endpoint, err
}

func GetScopes(hostName string, scopes []string) []string {
for _, s := range []string{oidc.ScopeOfflineAccess} {
if !hasScope(scopes, s) {
scopes = append(scopes, s)
}
}

cloudType := InferCloudFromHost(hostName)
if cloudType == Azure {
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
if !hasScope(scopes, userImpersonationScope) {
scopes = append(scopes, userImpersonationScope)
}
} else {
if !hasScope(scopes, "sql") {
scopes = append(scopes, "sql")
}
}

return scopes
}

func hasScope(scopes []string, scope string) bool {
for _, s := range scopes {
if s == scope {
return true
}
}
return false
}

var databricksAWSDomains []string = []string{
".cloud.databricks.com",
".dev.databricks.com",
}

var databricksAzureDomains []string = []string{
".azuredatabricks.net",
".databricks.azure.cn",
".databricks.azure.us",
}

type CloudType int

const (
AWS = iota
Azure
Unknown
)

func (cl CloudType) String() string {
switch cl {
case AWS:
return "AWS"
case Azure:
return "Azure"
}

return "Unknown"
}

func InferCloudFromHost(hostname string) CloudType {

for _, d := range databricksAzureDomains {
if strings.Contains(hostname, d) {
return Azure
}
}

for _, d := range databricksAWSDomains {
if strings.Contains(hostname, d) {
return AWS
}
}

return Unknown
}
47 changes: 47 additions & 0 deletions auth/oauth/pkce/pkce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package pkce

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"io"

"golang.org/x/oauth2"
)

// Generate generates a new random PKCE code.
func Generate() (Code, error) { return generate(rand.Reader) }

func generate(rand io.Reader) (Code, error) {
// From https://tools.ietf.org/html/rfc7636#section-4.1:
// code_verifier = high-entropy cryptographic random STRING using the
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters
// and a maximum length of 128 characters.
var buf [32]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", fmt.Errorf("could not generate PKCE code: %w", err)
}
return Code(hex.EncodeToString(buf[:])), nil
}

// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE).
type Code string

// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge.
func (p *Code) Challenge() oauth2.AuthCodeOption {
b := sha256.Sum256([]byte(*p))
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:]))
}

// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method.
func (p *Code) Method() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_challenge_method", "S256")
}

// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier.
func (p *Code) Verifier() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_verifier", string(*p))
}
Loading

0 comments on commit 34599c4

Please sign in to comment.