Skip to content

Commit

Permalink
oauth2-client-update (#70)
Browse files Browse the repository at this point in the history
tokenclient update
Authored-by: flaszlo2
  • Loading branch information
flaszlo2 authored Jun 6, 2024
1 parent 0b82aed commit 899d72c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
8 changes: 6 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,14 @@ func (c *Client) SetBasicAuth(username, password string) *Client {
return c
}

// SetOauth2Conf makes client obtain OAuth2 access token with given grant.
// SetOauth2Conf initializes OAuth2 configuration with given grant.
// Depending on specific setup, custom http.Client can be added to obtain access tokens.
// Either on first request to be sent or later when the obtained access token is expired.
//
// Make sure encrypted transport is used, e.g. the link is https.
// If client's HTTPS() has been called earlier, then token URL is checked accordingly.
// If token URL does not meet those requirements, then client credentials auth is not activated and error log is printed.
func (c *Client) SetOauth2Conf(config oauth2.Config, grant ...Grant) *Client {
func (c *Client) SetOauth2Conf(config oauth2.Config, tokenClient *http.Client, grant ...Grant) *Client {
if c.httpsCfg != nil {
tokenURL, err := url.Parse(config.Endpoint.TokenURL)
if err == nil {
Expand All @@ -368,6 +369,9 @@ func (c *Client) SetOauth2Conf(config oauth2.Config, grant ...Grant) *Client {
}
}
c.oauth2.config = &config
if c.oauth2.client == nil && tokenClient != nil {
c.oauth2.client = tokenClient
}
return c
}

Expand Down
16 changes: 8 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,23 +618,23 @@ func TestCtxCancelBefore(t *testing.T) {
}

func TestSetClientCredentialAuthDown(t *testing.T) {
client := NewClient().HTTPS(nil).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "https://0.0.0.0:1"}})
client := NewClient().HTTPS(nil).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "https://0.0.0.0:1"}}, nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := client.Get(ctx, "https://127.0.0.1", nil)
assert.Contains(t, err.Error(), "0.0.0.0:1")
}

func TestSetClientCredentialAuthDownAllowedTarget(t *testing.T) {
client := NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"0.0.0.0"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "https://0.0.0.0:1"}})
client := NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"0.0.0.0"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "https://0.0.0.0:1"}}, nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := client.Get(ctx, "https://127.0.0.1", nil)
assert.Contains(t, err.Error(), "0.0.0.0:1")
}

func TestSetClientCredentialNotAllowedTarget(t *testing.T) {
client := NewClient().HTTPS(nil).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "http://0.0.0.0:1"}})
client := NewClient().HTTPS(nil).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: "http://0.0.0.0:1"}}, nil)
assert.Nil(t, client.oauth2.config)
assert.NotNil(t, client)
}
Expand Down Expand Up @@ -662,13 +662,13 @@ func TestOauth2AccessTokenReqs(t *testing.T) {
req, _ := http.NewRequest("GET", "http://127.0.0.1", nil)

// Test client with invalid grant, defaulting to client credentials
client := NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}, "garbage")
client := NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}, nil, "garbage")
err := client.setOauth2Auth(ctx, req)
assert.NoError(t, err)
assert.Equal(t, client.oauth2.token.AccessToken, accesToken)

// Test client with password credentials grant
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}, GrantPasswordCredentials)
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}, nil, GrantPasswordCredentials)
assert.Equal(t, len(client.oauth2.config.Scopes), 0)
client.SetBasicAuth("user", "pass")
err = client.setOauth2Auth(ctx, req)
Expand All @@ -678,7 +678,7 @@ func TestOauth2AccessTokenReqs(t *testing.T) {
req.Header.Del("Authorization")

// Test client with refresh token grant
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}, Scopes: []string{"openid", "profile"}}, GrantRefreshToken)
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}, Scopes: []string{"openid", "profile"}}, nil, GrantRefreshToken)
assert.Equal(t, len(client.oauth2.config.Scopes), 2)
client.SetBasicAuth("user", "pass")
assert.Equal(t, client.oauth2.token.RefreshToken, "")
Expand All @@ -694,14 +694,14 @@ func TestOauth2AccessTokenReqs(t *testing.T) {
req.Header.Del("Authorization")

// Test client with default client credentials grant
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}, Scopes: []string{"openid", "profile"}})
client = NewClient().HTTPS(&HTTPSConfig{AllowedHTTPHosts: []string{"127.0.0.1"}}).SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}, Scopes: []string{"openid", "profile"}}, nil)
assert.Equal(t, len(client.oauth2.config.Scopes), 2)
err = client.setOauth2Auth(ctx, req)
assert.NoError(t, err)
assert.Equal(t, client.oauth2.token.AccessToken, accesToken)

// Test h2 OAuth2 client
client = NewClient().SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}).SetOauth2H2()
client = NewClient().SetOauth2Conf(oauth2.Config{ClientID: "id", ClientSecret: "secret", Endpoint: oauth2.Endpoint{TokenURL: authSrv.URL}}, nil).SetOauth2H2()
err = client.setOauth2Auth(ctx, req)
assert.Error(t, err) // h2 is not allowed for clear text http URL.
}
Expand Down

0 comments on commit 899d72c

Please sign in to comment.