Skip to content

Commit

Permalink
SDK-3949: Introduce sending client information with requests (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
ewanharris authored Feb 13, 2023
1 parent 67001c2 commit aa9273b
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 13 deletions.
55 changes: 55 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,34 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"encoding/base64"
"encoding/json"

"github.com/auth0/go-auth0"
)

// UserAgent is the default user agent string.
var UserAgent = fmt.Sprintf("Go-Auth0-SDK/%s", auth0.Version)

// Auth0ClientInfo is the structure used to send client information in the "Auth0-Client" header.
type Auth0ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Env map[string]string `json:"env,omitempty"`
}

// IsEmpty checks whether the provided Auth0ClientInfo data is nil or has no data to allow
// short-circuiting the "Auth0-Client" header configuration.
func (td *Auth0ClientInfo) IsEmpty() bool {
if td == nil {
return true
}
return td.Name == "" && td.Version == "" && len(td.Env) == 0
}

// DefaultAuth0ClientInfo is the default client information sent by the go-auth0 SDK.
var DefaultAuth0ClientInfo = &Auth0ClientInfo{Name: "go-auth0", Version: auth0.Version}

// RoundTripFunc is an adapter to allow the use of ordinary functions as HTTP
// round trips.
type RoundTripFunc func(*http.Request) (*http.Response, error)
Expand Down Expand Up @@ -69,6 +91,25 @@ func UserAgentTransport(base http.RoundTripper, userAgent string) http.RoundTrip
})
}

// Auth0ClientInfoTransport wraps base transport with a customized "Auth0-Client" header.
func Auth0ClientInfoTransport(base http.RoundTripper, auth0ClientInfo *Auth0ClientInfo) (http.RoundTripper, error) {
if base == nil {
base = http.DefaultTransport
}

auth0ClientJson, err := json.Marshal(auth0ClientInfo)
if err != nil {
return nil, err
}

auth0ClientEncoded := base64.StdEncoding.EncodeToString(auth0ClientJson)

return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
req.Header.Set("Auth0-Client", auth0ClientEncoded)
return base.RoundTrip(req)
}), nil
}

func dumpRequest(r *http.Request) {
b, _ := httputil.DumpRequestOut(r, true)
log.Printf("\n%s\n", b)
Expand Down Expand Up @@ -123,6 +164,20 @@ func WithUserAgent(userAgent string) Option {
}
}

// WithAuth0ClientInfo configures the client to overwrite the "Auth0-Client" header.
func WithAuth0ClientInfo(auth0ClientInfo *Auth0ClientInfo) Option {
return func(c *http.Client) {
if auth0ClientInfo.IsEmpty() {
return
}
transport, err := Auth0ClientInfoTransport(c.Transport, auth0ClientInfo)
if err != nil {
return
}
c.Transport = transport
}
}

// Wrap the base client with transports that enable OAuth2 authentication.
func Wrap(base *http.Client, tokenSource oauth2.TokenSource, options ...Option) *http.Client {
if base == nil {
Expand Down
42 changes: 42 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,45 @@ func TestOAuth2ClientCredentialsAndAudience(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "someToken", token.AccessToken)
}

func TestWrapAuth0ClientInfo(t *testing.T) {
var testCases = []struct {
name string
given Auth0ClientInfo
expected string
}{
{
name: "Default client",
given: *DefaultAuth0ClientInfo,
expected: "eyJuYW1lIjoiZ28tYXV0aDAiLCJ2ZXJzaW9uIjoibGF0ZXN0In0=",
},
{
name: "Custom client",
given: Auth0ClientInfo{"foo", "1.0.0", map[string]string{"os": "windows"}},
expected: "eyJuYW1lIjoiZm9vIiwidmVyc2lvbiI6IjEuMC4wIiwiZW52Ijp7Im9zIjoid2luZG93cyJ9fQ==",
},
{
name: "Missing information",
given: Auth0ClientInfo{Name: "foo"},
expected: "eyJuYW1lIjoiZm9vIiwidmVyc2lvbiI6IiJ9",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := r.Header.Get("Auth0-Client")
assert.Equal(t, testCase.expected, header)
})

testServer := httptest.NewServer(testHandler)
t.Cleanup(func() {
testServer.Close()
})

httpClient := Wrap(testServer.Client(), StaticToken(""), WithAuth0ClientInfo(&testCase.given))
_, err := httpClient.Get(testServer.URL)
assert.NoError(t, err)
})
}
}
29 changes: 16 additions & 13 deletions management/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ type Management struct {
// EmailProvider manages Auth0 Email Providers.
EmailProvider *EmailProviderManager

url *url.URL
basePath string
userAgent string
debug bool
ctx context.Context
tokenSource oauth2.TokenSource
http *http.Client
url *url.URL
basePath string
userAgent string
debug bool
ctx context.Context
tokenSource oauth2.TokenSource
http *http.Client
auth0ClientInfo *client.Auth0ClientInfo
}

// New creates a new Auth0 Management client by authenticating using the
Expand All @@ -133,12 +134,13 @@ func New(domain string, options ...Option) (*Management, error) {
}

m := &Management{
url: u,
basePath: "api/v2",
userAgent: client.UserAgent,
debug: false,
ctx: context.Background(),
http: http.DefaultClient,
url: u,
basePath: "api/v2",
userAgent: client.UserAgent,
debug: false,
ctx: context.Background(),
http: http.DefaultClient,
auth0ClientInfo: client.DefaultAuth0ClientInfo,
}

for _, option := range options {
Expand All @@ -151,6 +153,7 @@ func New(domain string, options ...Option) (*Management, error) {
client.WithDebug(m.debug),
client.WithUserAgent(m.userAgent),
client.WithRateLimit(),
client.WithAuth0ClientInfo(m.auth0ClientInfo),
)

m.Client = newClientManager(m)
Expand Down
18 changes: 18 additions & 0 deletions management/management_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,21 @@ func WithClient(client *http.Client) Option {
m.http = client
}
}

// WithAuth0ClientInfo configures the management client to use the provided client information
// instead of the default one.
func WithAuth0ClientInfo(auth0ClientInfo client.Auth0ClientInfo) Option {
return func(m *Management) {
if !auth0ClientInfo.IsEmpty() {
m.auth0ClientInfo = &auth0ClientInfo
}
}
}

// WithNoAuth0ClientInfo configures the management client to not send the "Auth0-Client" header
// at all.
func WithNoAuth0ClientInfo() Option {
return func(m *Management) {
m.auth0ClientInfo = nil
}
}
60 changes: 60 additions & 0 deletions management/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (

_ "github.com/joho/godotenv/autoload"
"github.com/stretchr/testify/assert"

"github.com/auth0/go-auth0/internal/client"
)

var (
Expand Down Expand Up @@ -254,3 +256,61 @@ func TestManagement_URI(t *testing.T) {
})
}
}

func TestAuth0Client(t *testing.T) {
t.Run("Defaults to the default data", func(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := r.Header.Get("Auth0-Client")
assert.Equal(t, "eyJuYW1lIjoiZ28tYXV0aDAiLCJ2ZXJzaW9uIjoibGF0ZXN0In0=", header)
})
s := httptest.NewServer(h)

m, err := New(
s.URL,
WithInsecure(),
)
assert.NoError(t, err)

_, err = m.User.Read("123")

assert.NoError(t, err)
})

t.Run("Allows passing custom Auth0ClientInfo", func(t *testing.T) {
customClient := client.Auth0ClientInfo{Name: "test-client", Version: "1.0.0"}

h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := r.Header.Get("Auth0-Client")
assert.Equal(t, "eyJuYW1lIjoidGVzdC1jbGllbnQiLCJ2ZXJzaW9uIjoiMS4wLjAifQ==", header)
})
s := httptest.NewServer(h)

m, err := New(
s.URL,
WithInsecure(),
WithAuth0ClientInfo(customClient),
)
assert.NoError(t, err)

_, err = m.User.Read("123")

assert.NoError(t, err)
})

t.Run("Allows disabling Auth0ClientInfo", func(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawHeader := r.Header.Get("Auth0-Client")
assert.Empty(t, rawHeader)
})
s := httptest.NewServer(h)

m, err := New(
s.URL,
WithInsecure(),
WithNoAuth0ClientInfo(),
)
assert.NoError(t, err)
_, err = m.User.Read("123")
assert.NoError(t, err)
})
}

0 comments on commit aa9273b

Please sign in to comment.