Skip to content

Commit

Permalink
Add explicit support for User/PW, client credentials and Bearer token…
Browse files Browse the repository at this point in the history
… oidc auth flows (#101)
  • Loading branch information
dirkkul authored Jan 31, 2023
1 parent 2ceca4d commit 38d0af8
Show file tree
Hide file tree
Showing 47 changed files with 1,177 additions and 185 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,16 @@ jobs:
tests-v4:
name: Tests v4
runs-on: ubuntu-latest
strategy:
matrix:
auth_integration: [ "auth_enabled", "auth_disabled" ]
env:
EXTERNAL_WEAVIATE_RUNNING: false
AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }}
WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }}
OKTA_DUMMY_CI_PW: ${{ secrets.OKTA_DUMMY_CI_PW }}
INTEGRATION_TESTS_AUTH: ${{ matrix.auth_integration }}
steps:
- uses: actions/checkout@v3
- name: Login to Docker Hub
Expand All @@ -82,9 +90,12 @@ jobs:
with:
go-version: 1.19
cache: true
- name: Start Weaviate
run: ./v4/test/start_containers.sh
- name: Run tests
run: |
cd v4
docker-compose -f test/docker-compose.yaml up -d
go test -v ./weaviate/...
( for pkg in $(go list ./... | grep 'weaviate-go-client/v4/test'); do if ! go test -v -count 1 -race "$pkg"; then echo "Test for $pkg failed" >&2; false; exit; fi; done)
- name: Stop Weaviate
run: ./v4/test/stop_containers.sh
1 change: 1 addition & 0 deletions v4/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ require (
github.com/go-openapi/strfmt v0.21.3
github.com/stretchr/testify v1.8.0
github.com/weaviate/weaviate v1.17.2-0.20230118094121-abf30eac8656
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
)
2 changes: 2 additions & 0 deletions v4/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,7 @@ golang.org/x/oauth2 v0.0.0-20220309155454-6242fa91716a/go.mod h1:DAh4E804XQdzx2j
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc=
golang.org/x/oauth2 v0.0.0-20220608161450-d0670ef3b1eb/go.mod h1:jaDAt6Dkxork7LmZnYtzbRWj0W47D86a3TGe0YHBvmE=
golang.org/x/oauth2 v0.0.0-20220622183110-fd043fe589d2/go.mod h1:jaDAt6Dkxork7LmZnYtzbRWj0W47D86a3TGe0YHBvmE=
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094 h1:2o1E+E8TpNLklK9nHiPiK1uzIYrIHt+cQx3ynCwq9V8=
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094/go.mod h1:h4gKUeWbJ4rQPri7E0u6Gs4e9Ri2zaLxzw5DI5XGrYg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down Expand Up @@ -1341,6 +1342,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/cloud v0.0.0-20151119220103-975617b05ea8/go.mod h1:0H1ncTHf11KCFhTc/+EFRbzSCOZx+VUbRMk55Yv5MYk=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
Expand Down
267 changes: 267 additions & 0 deletions v4/test/auth/auth_mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package test

import (
"bytes"
"context"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate-go-client/v4/weaviate/auth"
)

const (
AccessToken = "HELLO!IamAnAccessToken"
RefreshToken = "IAmARefreshToken"
)

// Test that the client warns when no refresh token is provided by the authentication provider
func TestAuthMock_NoRefreshToken(t *testing.T) {
tests := []struct {
name string
authConfig auth.Config
scope []string
}{
{name: "User/PW", authConfig: auth.ResourceOwnerPasswordFlow{Username: "SomeUsername", Password: "IamWrong"}},
{name: "Bearer token", authConfig: auth.BearerToken{AccessToken: "NotAToken"}},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// write log to buffer
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()

// endpoint for access tokens
muxToken := http.NewServeMux()
muxToken.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprint(`{"access_token": "` + AccessToken + `", "expires_in": "5"}`)))
})
sToken := httptest.NewServer(muxToken)
defer sToken.Close()

// provides all endpoints
muxEndpoints := http.NewServeMux()
muxEndpoints.HandleFunc("/endpoints", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(`{"token_endpoint": "` + sToken.URL + `/auth"}`)))
})
sEndpoints := httptest.NewServer(muxEndpoints)
defer sEndpoints.Close()

// Returns the address of the auth server
mux := http.NewServeMux()
mux.HandleFunc("/v1/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"href": "` + sEndpoints.URL + `/endpoints", "clientId": "DoesNotMatter"}`))
})
mux.HandleFunc("/v1/schema", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{}`))
})
s := httptest.NewServer(mux)
defer s.Close()

cfg, err := weaviate.NewConfig(strings.TrimPrefix(s.URL, "http://"), "http", tc.authConfig, nil)
assert.Nil(t, err)
assert.True(t, strings.Contains(buf.String(), "Auth002"))

client := weaviate.New(*cfg)
AuthErr := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr)
})
}
}

// Test that client using CC automatically get a new token after expiration
func TestAuthMock_RefreshCC(t *testing.T) {
i := 0
// endpoint for access tokens
muxToken := http.NewServeMux()
muxToken.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
i += 1 // record how often this was called
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprint(`{"access_token": "` + AccessToken + `", "expires_in": "1"}`)))
})
sToken := httptest.NewServer(muxToken)
defer sToken.Close()

// provides all endpoints
muxEndpoints := http.NewServeMux()
muxEndpoints.HandleFunc("/endpoints", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(`{"token_endpoint": "` + sToken.URL + `/auth"}`)))
})
sEndpoints := httptest.NewServer(muxEndpoints)
defer sEndpoints.Close()

// Returns the address of the auth server
mux := http.NewServeMux()
mux.HandleFunc("/v1/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"href": "` + sEndpoints.URL + `/endpoints", "clientId": "DoesNotMatter"}`))
})
mux.HandleFunc("/v1/schema", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{}`))
})
s := httptest.NewServer(mux)
defer s.Close()

cfg, err := weaviate.NewConfig(strings.TrimPrefix(s.URL, "http://"), "http", auth.ClientCredentials{ClientSecret: "SecretValue"}, nil)
assert.Nil(t, err)
client := weaviate.New(*cfg)
AuthErr := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr)
assert.Equal(t, i, 3) // client does 3 initial calls to token endpoint

time.Sleep(time.Second * 5)
// current token expires, so the oauth client needs to get a new one
AuthErr2 := client.Schema().AllDeleter().Do(context.TODO())
assert.Equal(t, i, 4)
assert.Nil(t, AuthErr2)
}

// Test that client uses refresh tokens to get new access/refresh tokens before their expiration, including during idle
// times.
func TestAuthMock_RefreshUserPWAndToken(t *testing.T) {
expirationTimeRefreshToken := 3
expirationTimeToken := uint(2)
tests := []struct {
name string
authConfig auth.Config
scope []string
}{
{name: "User/PW", authConfig: auth.ResourceOwnerPasswordFlow{Username: "SomeUsername", Password: "IamWrong"}},
{name: "Bearer token", authConfig: auth.BearerToken{
AccessToken: AccessToken, ExpiresIn: expirationTimeToken, RefreshToken: RefreshToken,
}},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tokenRefreshTime := time.Now()
// endpoint for access tokens
muxToken := http.NewServeMux()
muxToken.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
// refresh token cannot be expired
assert.True(t, time.Now().Sub(tokenRefreshTime).Seconds() < float64(expirationTimeRefreshToken))

tokenRefreshTime = time.Now() // update time when the tokens where refreshed the last time
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(
fmt.Sprintf(`{"access_token": "%v", "expires_in": %v, "refresh_token": "%v", "refresh_expires_in" : %v}`,
AccessToken, expirationTimeToken, RefreshToken, expirationTimeRefreshToken)))
})
sToken := httptest.NewServer(muxToken)
defer sToken.Close()

// provides all endpoints
muxEndpoints := http.NewServeMux()
muxEndpoints.HandleFunc("/endpoints", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(`{"token_endpoint": "` + sToken.URL + `/auth"}`)))
})
sEndpoints := httptest.NewServer(muxEndpoints)
defer sEndpoints.Close()

// Returns the address of the auth server
mux := http.NewServeMux()
mux.HandleFunc("/v1/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"href": "` + sEndpoints.URL + `/endpoints", "clientId": "DoesNotMatter"}`))
})
mux.HandleFunc("/v1/schema", func(w http.ResponseWriter, r *http.Request) {
// Access Token cannot be expired
assert.True(t, time.Now().Sub(tokenRefreshTime).Seconds() < float64(expirationTimeToken))
w.Write([]byte(`{}`))
})
s := httptest.NewServer(mux)
defer s.Close()

cfg, err := weaviate.NewConfig(strings.TrimPrefix(s.URL, "http://"), "http", tc.authConfig, nil)
assert.Nil(t, err)
client := weaviate.New(*cfg)
AuthErr := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr)

// access and refresh token expired, so the client needs to refresh automatically in the background
time.Sleep(time.Second * 5)
AuthErr2 := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr2)
})
}
}

// Test that the client can handle situations in which a proxy returns a catchall page for all requests
func TestAuthMock_CatchAllProxy(t *testing.T) {
// write log to buffer
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()

// Simulate a proxy that returns something if a page is not available => no valid json
mux := http.NewServeMux()
mux.HandleFunc("/v1/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`NotAValidJsonResponse`))
})
mux.HandleFunc("/v1/schema", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{}`))
})
s := httptest.NewServer(mux)
defer s.Close()

cfg, err := weaviate.NewConfig(strings.TrimPrefix(s.URL, "http://"), "http", nil, nil)
assert.Nil(t, err)
client := weaviate.New(*cfg)
AuthErr := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr)
}

// Test that client using CC automatically get a new token after expiration
func TestAuthMock_CheckDefaultScopes(t *testing.T) {
// endpoint for access tokens
muxToken := http.NewServeMux()
muxToken.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
bodyS := string(body)
assert.Equal(t, bodyS[len(bodyS)-15:], "something+extra") // scopes are in the body

w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprint(`{"access_token": "` + AccessToken + `", "expires_in": "1"}`)))
})
sToken := httptest.NewServer(muxToken)
defer sToken.Close()

// provides all endpoints
muxEndpoints := http.NewServeMux()
muxEndpoints.HandleFunc("/endpoints", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf(`{"token_endpoint": "` + sToken.URL + `/auth"}`)))
})
sEndpoints := httptest.NewServer(muxEndpoints)
defer sEndpoints.Close()

// Returns the address of the auth server
mux := http.NewServeMux()
mux.HandleFunc("/v1/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"href": "` + sEndpoints.URL + `/endpoints", "clientId": "DoesNotMatter", "scopes": ["something", "extra"]}`))
})
mux.HandleFunc("/v1/schema", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{}`))
})
s := httptest.NewServer(mux)
defer s.Close()

cfg, err := weaviate.NewConfig(strings.TrimPrefix(s.URL, "http://"), "http", auth.ClientCredentials{ClientSecret: "SecretValue"}, nil)
assert.Nil(t, err)
client := weaviate.New(*cfg)
AuthErr := client.Schema().AllDeleter().Do(context.TODO())
assert.Nil(t, AuthErr)
}
Loading

0 comments on commit 38d0af8

Please sign in to comment.