Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial oauth implementation #122

Merged
merged 6 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
package auth

import "net/http"
import (
"net/http"
"strings"
)

type Authenticator interface {
Authenticate(*http.Request) error
}

type AuthType int

const (
AuthTypeUnknown AuthType = iota
AuthTypePat
AuthTypeOauthU2M
AuthTypeOauthM2M
)

var authTypeNames []string = []string{"Unknown", "Pat", "OauthU2M", "OauthM2M"}

func (at AuthType) String() string {
if at >= 0 && int(at) < len(authTypeNames) {
return authTypeNames[at]
}

return authTypeNames[0]
}

func ParseAuthType(typeString string) AuthType {
typeString = strings.ToLower(typeString)
for i, n := range authTypeNames {
if strings.ToLower(n) == typeString {
return AuthType(i)
}
}

return AuthTypeUnknown
}
91 changes: 91 additions & 0 deletions auth/oauth/m2m/m2m.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package m2m

// clientid e92aa085-4875-42fe-ad75-ba38fb3c9706
// secretid vUdzecmn4aUi2jRDamaBOy3qThu9LSgeV_BW4UnQ

import (
"context"
"fmt"
"net/http"
"sync"

"github.com/databricks/databricks-sql-go/auth"
"github.com/databricks/databricks-sql-go/auth/oauth"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator {
scopes := oauth.GetScopes(hostName, []string{})
return &authClient{
clientID: clientID,
clientSecret: clientSecret,
hostName: hostName,
scopes: scopes,
}
}

type authClient struct {
clientID string
clientSecret string
hostName string
scopes []string
tokenSource oauth2.TokenSource
mx sync.Mutex
}

// Auth will start the OAuth Authorization Flow to authenticate the cli client
// using the users credentials in the browser. Compatible with SSO.
func (c *authClient) Authenticate(r *http.Request) error {
c.mx.Lock()
defer c.mx.Unlock()
if c.tokenSource != nil {
token, err := c.tokenSource.Token()
if err != nil {
return err
}
token.SetAuthHeader(r)
return nil
}

config, err := GetConfig(context.Background(), c.hostName, c.clientID, c.clientSecret, c.scopes)
if err != nil {
return fmt.Errorf("unable to generate clientCredentials.Config: %w", err)
}

c.tokenSource = GetTokenSource(config)
token, err := c.tokenSource.Token()
log.Info().Msgf("token fetched successfully")
if err != nil {
log.Err(err).Msg("failed to get token")

return err
}
token.SetAuthHeader(r)

return nil

}

func GetTokenSource(config clientcredentials.Config) oauth2.TokenSource {
tokenSource := config.TokenSource(context.Background())
return tokenSource
}

func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, scopes []string) (clientcredentials.Config, error) {
// Get the endpoint based on the host name
endpoint, err := oauth.GetEndpoint(ctx, issuerURL)
if err != nil {
return clientcredentials.Config{}, fmt.Errorf("could not lookup provider details: %w", err)
}

config := clientcredentials.Config{
ClientID: clientID,
ClientSecret: clientSecret,
TokenURL: endpoint.TokenURL,
Scopes: scopes,
}

return config, nil
}
122 changes: 122 additions & 0 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package oauth

import (
"context"
"errors"
"fmt"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)

const (
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68"
)

func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) {
if ctx == nil {
ctx = context.Background()
}

cloud := InferCloudFromHost(hostName)

if cloud == Unknown {
return oauth2.Endpoint{}, errors.New("unsupported cloud type")
}

if cloud == Azure {
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName)
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName)
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil
}

issuerURL := fmt.Sprintf("https://%s/oidc", hostName)
ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL)
provider, err := oidc.NewProvider(ctx, issuerURL)
if err != nil {
return oauth2.Endpoint{}, err
}

endpoint := provider.Endpoint()

return endpoint, err
}

func GetScopes(hostName string, scopes []string) []string {
for _, s := range []string{oidc.ScopeOfflineAccess} {
if !hasScope(scopes, s) {
scopes = append(scopes, s)
}
}

cloudType := InferCloudFromHost(hostName)
if cloudType == Azure {
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
if !hasScope(scopes, userImpersonationScope) {
scopes = append(scopes, userImpersonationScope)
}
} else {
if !hasScope(scopes, "sql") {
scopes = append(scopes, "sql")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use sql or all-apis? Gotta check with Bryan Mcquade

}
}

return scopes
}

func hasScope(scopes []string, scope string) bool {
for _, s := range scopes {
if s == scope {
return true
}
}
return false
}

var databricksAWSDomains []string = []string{
".cloud.databricks.com",
".dev.databricks.com",
}

var databricksAzureDomains []string = []string{
".azuredatabricks.net",
".databricks.azure.cn",
".databricks.azure.us",
}

type CloudType int

const (
AWS = iota
Azure
Unknown
)

func (cl CloudType) String() string {
switch cl {
case AWS:
return "AWS"
case Azure:
return "Azure"
}

return "Unknown"
}

func InferCloudFromHost(hostname string) CloudType {

for _, d := range databricksAzureDomains {
if strings.Contains(hostname, d) {
return Azure
}
}

for _, d := range databricksAWSDomains {
if strings.Contains(hostname, d) {
return AWS
}
}

return Unknown
}
47 changes: 47 additions & 0 deletions auth/oauth/pkce/pkce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package pkce

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"io"

"golang.org/x/oauth2"
)

// Generate generates a new random PKCE code.
func Generate() (Code, error) { return generate(rand.Reader) }

func generate(rand io.Reader) (Code, error) {
// From https://tools.ietf.org/html/rfc7636#section-4.1:
// code_verifier = high-entropy cryptographic random STRING using the
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters
// and a maximum length of 128 characters.
var buf [32]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", fmt.Errorf("could not generate PKCE code: %w", err)
}
return Code(hex.EncodeToString(buf[:])), nil
}

// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE).
type Code string

// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge.
func (p *Code) Challenge() oauth2.AuthCodeOption {
b := sha256.Sum256([]byte(*p))
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:]))
}

// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method.
func (p *Code) Method() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_challenge_method", "S256")
}

// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier.
func (p *Code) Verifier() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_verifier", string(*p))
}
Loading
Loading