-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
949 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
vendor | ||
.idea | ||
pkg/oidc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package oidc | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
|
||
auth "github.com/formancehq/auth/pkg" | ||
"github.com/formancehq/auth/pkg/delegatedauth" | ||
"github.com/google/uuid" | ||
"github.com/zitadel/oidc/pkg/client/rp" | ||
"github.com/zitadel/oidc/pkg/op" | ||
) | ||
|
||
func authorizeCallbackHandler( | ||
provider op.OpenIDProvider, | ||
storage Storage, | ||
relyingParty rp.RelyingParty, | ||
) http.HandlerFunc { | ||
return func(w http.ResponseWriter, r *http.Request) { | ||
|
||
state, err := delegatedauth.DecodeDelegatedState(r.URL.Query().Get("state")) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
authRequest, err := storage.FindAuthRequest(context.Background(), state.AuthRequestID) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
tokens, err := rp.CodeExchange(context.Background(), r.URL.Query().Get("code"), relyingParty) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
user, err := storage.FindUserBySubject(r.Context(), tokens.IDTokenClaims.GetSubject()) | ||
if err != nil { | ||
user = &auth.User{ | ||
ID: uuid.NewString(), | ||
Subject: tokens.IDTokenClaims.GetSubject(), | ||
Email: tokens.IDTokenClaims.GetEmail(), | ||
} | ||
if err := storage.SaveUser(r.Context(), *user); err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
authRequest.UserID = user.ID | ||
|
||
if err := storage.UpdateAuthRequest(r.Context(), *authRequest); err != nil { | ||
panic(err) | ||
} | ||
|
||
w.Header().Set("Location", op.AuthCallbackURL(provider)(state.AuthRequestID)) | ||
w.WriteHeader(http.StatusFound) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package oidc | ||
|
||
import ( | ||
"time" | ||
|
||
auth "github.com/formancehq/auth/pkg" | ||
"github.com/formancehq/auth/pkg/delegatedauth" | ||
"github.com/zitadel/oidc/pkg/client/rp" | ||
"github.com/zitadel/oidc/pkg/oidc" | ||
"github.com/zitadel/oidc/pkg/op" | ||
) | ||
|
||
type clientFacade struct { | ||
Client auth.Client | ||
relyingParty rp.RelyingParty | ||
} | ||
|
||
func NewClientFacade(client auth.Client, relyingParty rp.RelyingParty) *clientFacade { | ||
return &clientFacade{ | ||
Client: client, | ||
relyingParty: relyingParty, | ||
} | ||
} | ||
|
||
// GetID must return the client_id | ||
func (c *clientFacade) GetID() string { | ||
return c.Client.Id | ||
} | ||
|
||
// RedirectURIs must return the registered redirect_uris for Code and Implicit Flow | ||
func (c *clientFacade) RedirectURIs() []string { | ||
return c.Client.RedirectURIs | ||
} | ||
|
||
// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs | ||
func (c *clientFacade) PostLogoutRedirectURIs() []string { | ||
return c.Client.PostLogoutRedirectUris | ||
} | ||
|
||
// ApplicationType must return the type of the client (app, native, user agent) | ||
func (c *clientFacade) ApplicationType() op.ApplicationType { | ||
return c.Client.ApplicationType | ||
} | ||
|
||
// AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt) | ||
func (c *clientFacade) AuthMethod() oidc.AuthMethod { | ||
return c.Client.AuthMethod | ||
} | ||
|
||
// ResponseTypes must return all allowed response types (code, id_token token, id_token) | ||
// these must match with the allowed grant types | ||
func (c *clientFacade) ResponseTypes() []oidc.ResponseType { | ||
return c.Client.ResponseTypes | ||
} | ||
|
||
// GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer) | ||
func (c *clientFacade) GrantTypes() []oidc.GrantType { | ||
return c.Client.GrantTypes | ||
} | ||
|
||
// LoginURL will be called to redirect the user (agent) to the login UI | ||
// you could implement some logic here to redirect the users to different login UIs depending on the client | ||
func (c *clientFacade) LoginURL(id string) string { | ||
return rp.AuthURL(delegatedauth.DelegatedState{ | ||
AuthRequestID: id, | ||
}.EncodeAsUrlParam(), c.relyingParty) | ||
} | ||
|
||
// AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT) | ||
func (c *clientFacade) AccessTokenType() op.AccessTokenType { | ||
return c.Client.AccessTokenType | ||
} | ||
|
||
// IDTokenLifetime must return the lifetime of the client's id_tokens | ||
func (c *clientFacade) IDTokenLifetime() time.Duration { | ||
return 1 * time.Hour | ||
} | ||
|
||
// DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client) | ||
func (c *clientFacade) DevMode() bool { | ||
return c.Client.DevMode | ||
} | ||
|
||
// RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token | ||
func (c *clientFacade) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { | ||
return func(scopes []string) []string { | ||
return scopes | ||
} | ||
} | ||
|
||
// RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token | ||
func (c *clientFacade) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { | ||
return func(scopes []string) []string { | ||
return scopes | ||
} | ||
} | ||
|
||
// IsScopeAllowed enables Client specific custom scopes validation | ||
func (c *clientFacade) IsScopeAllowed(label string) bool { | ||
for _, scope := range c.Client.Scopes { | ||
if scope.Label == label { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
// IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token | ||
// even if an access token if issued which violates the OIDC Core spec | ||
// (5.4. Requesting Claims using Scope Values: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims) | ||
// some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued | ||
func (c *clientFacade) IDTokenUserinfoClaimsAssertion() bool { | ||
return c.Client.IdTokenUserinfoClaimsAssertion | ||
} | ||
|
||
// ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations | ||
// (subtract from issued_at, add to expiration, ...) | ||
func (c *clientFacade) ClockSkew() time.Duration { | ||
return c.Client.ClockSkew | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package oidc | ||
|
||
import ( | ||
"context" | ||
|
||
auth "github.com/formancehq/auth/pkg" | ||
"github.com/gorilla/mux" | ||
"github.com/zitadel/oidc/pkg/client/rp" | ||
"github.com/zitadel/oidc/pkg/op" | ||
"go.uber.org/fx" | ||
) | ||
|
||
func Module(addr, issuer string) fx.Option { | ||
return fx.Options( | ||
fx.Provide(NewRouter), | ||
fx.Provide(fx.Annotate(func(storage Storage, relyingParty rp.RelyingParty, opts []auth.ClientOptions) *storageFacade { | ||
var staticClients []auth.Client | ||
for _, c := range opts { | ||
staticClients = append(staticClients, *auth.NewClient(c)) | ||
} | ||
return NewStorageFacade(storage, relyingParty, staticClients...) | ||
}, fx.As(new(op.Storage)))), | ||
fx.Provide(func(storage op.Storage) (op.OpenIDProvider, error) { | ||
return NewOpenIDProvider(context.TODO(), storage, issuer) | ||
}), | ||
fx.Invoke(func(lc fx.Lifecycle, router *mux.Router) { | ||
lc.Append(fx.Hook{ | ||
OnStart: func(ctx context.Context) error { | ||
return StartServer(addr, router) | ||
}, | ||
}) | ||
}), | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package oidc_test | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"testing" | ||
|
||
auth "github.com/formancehq/auth/pkg" | ||
"github.com/formancehq/auth/pkg/oidc" | ||
"github.com/formancehq/auth/pkg/storage/sqlstorage" | ||
"github.com/oauth2-proxy/mockoidc" | ||
"github.com/stretchr/testify/require" | ||
"github.com/zitadel/oidc/pkg/client/rp" | ||
"github.com/zitadel/oidc/pkg/op" | ||
"gorm.io/driver/sqlite" | ||
) | ||
|
||
func init() { | ||
os.Setenv(op.OidcDevMode, "true") | ||
} | ||
|
||
func withServer(t *testing.T, fn func(storage *sqlstorage.Storage, provider op.OpenIDProvider)) { | ||
// Create a mock OIDC server which will always return a default user | ||
mockOIDC, err := mockoidc.Run() | ||
require.NoError(t, err) | ||
defer func() { | ||
require.NoError(t, mockOIDC.Shutdown()) | ||
}() | ||
|
||
// Prepare a tcp connection, listening on :0 to select a random port | ||
l, err := net.Listen("tcp", "127.0.0.1:0") | ||
require.NoError(t, err) | ||
|
||
// Compute server url, it will be the "issuer" of our oidc provider | ||
serverUrl := fmt.Sprintf("http://%s", l.Addr().String()) | ||
|
||
// As our oidc provider, is also a relying party (it delegates authentication), we need to construct a relying party | ||
// with information from the mock | ||
serverRelyingParty, err := rp.NewRelyingPartyOIDC(mockOIDC.Issuer(), mockOIDC.ClientID, mockOIDC.ClientSecret, | ||
fmt.Sprintf("%s/authorize/callback", serverUrl), []string{"openid", "email"}) | ||
require.NoError(t, err) | ||
|
||
// Construct our storage | ||
db, err := sqlstorage.LoadGorm(sqlite.Open(":memory:")) | ||
require.NoError(t, err) | ||
require.NoError(t, sqlstorage.MigrateTables(context.Background(), db)) | ||
|
||
storage := sqlstorage.New(db) | ||
storageFacade := oidc.NewStorageFacade(storage, serverRelyingParty) | ||
|
||
// Construct our oidc provider | ||
provider, err := oidc.NewOpenIDProvider(context.TODO(), storageFacade, serverUrl) | ||
require.NoError(t, err) | ||
|
||
// Create the router | ||
router := oidc.NewRouter(provider, storage, serverRelyingParty) | ||
|
||
// Create our http server for our oidc provider | ||
providerHttpServer := &http.Server{ | ||
Handler: router, | ||
} | ||
go func() { | ||
err := providerHttpServer.Serve(l) | ||
if err != http.ErrServerClosed { | ||
require.Fail(t, err.Error()) | ||
} | ||
}() | ||
defer providerHttpServer.Close() | ||
|
||
fn(storage, provider) | ||
} | ||
|
||
func Test3LeggedFlow(t *testing.T) { | ||
|
||
withServer(t, func(storage *sqlstorage.Storage, provider op.OpenIDProvider) { | ||
// Create ou http server for our client (a web application for example) | ||
code := make(chan string, 1) // Just store codes coming from our provider inside a chan | ||
clientHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
code <- r.URL.Query().Get("code") | ||
}) | ||
clientHttpServer := httptest.NewServer(clientHandler) | ||
defer clientHttpServer.Close() | ||
|
||
// Create a OAuth2 client which represent our client application | ||
client := auth.NewClient(auth.ClientOptions{}) | ||
client.RedirectURIs.Append(clientHttpServer.URL) // Need to configure the redirect uri | ||
_, clear := client.GenerateNewSecret(auth.SecretCreate{}) // Need to generate a secret | ||
require.NoError(t, storage.SaveClient(context.TODO(), *client)) | ||
|
||
// As our client is a relying party, we can use the library to get some helpers | ||
clientRelyingParty, err := rp.NewRelyingPartyOIDC(provider.Issuer(), client.Id, clear, client.RedirectURIs[0], []string{"openid", "email"}) | ||
require.NoError(t, err) | ||
|
||
// Trigger an authentication request | ||
authUrl := rp.AuthURL("", clientRelyingParty) | ||
rsp, err := (&http.Client{ | ||
CheckRedirect: func(req *http.Request, via []*http.Request) error { | ||
fmt.Println(req.URL.String()) | ||
return nil | ||
}, | ||
}).Get(authUrl) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusOK, rsp.StatusCode) | ||
|
||
select { | ||
// As the mock automatically accept login response, we should have received a code | ||
case code := <-code: | ||
// And this code is used to get a token | ||
_, err := rp.CodeExchange(context.TODO(), code, clientRelyingParty) | ||
require.NoError(t, err) | ||
default: | ||
require.Fail(t, "code was expected") | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package oidc | ||
|
||
import ( | ||
"context" | ||
"crypto/sha256" | ||
|
||
"github.com/zitadel/oidc/pkg/op" | ||
"golang.org/x/text/language" | ||
) | ||
|
||
const ( | ||
pathLoggedOut = "/logged-out" | ||
) | ||
|
||
func NewOpenIDProvider(ctx context.Context, storage op.Storage, issuer string) (op.OpenIDProvider, error) { | ||
return op.NewOpenIDProvider(ctx, &op.Config{ | ||
Issuer: issuer, | ||
CryptoKey: sha256.Sum256([]byte("test")), | ||
DefaultLogoutRedirectURI: pathLoggedOut, | ||
CodeMethodS256: true, | ||
AuthMethodPost: true, | ||
AuthMethodPrivateKeyJWT: true, | ||
GrantTypeRefreshToken: true, | ||
RequestObjectSupported: true, | ||
SupportedUILocales: []language.Tag{language.English}, | ||
}, storage) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package oidc | ||
|
||
import ( | ||
"net/http" | ||
|
||
"github.com/gorilla/mux" | ||
"github.com/zitadel/oidc/pkg/client/rp" | ||
"github.com/zitadel/oidc/pkg/op" | ||
"go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" | ||
) | ||
|
||
func NewRouter(provider op.OpenIDProvider, storage Storage, | ||
relyingParty rp.RelyingParty) *mux.Router { | ||
router := provider.HttpHandler().(*mux.Router) | ||
router.Use(otelmux.Middleware("auth")) | ||
router.Use(func(handler http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Set("Content-Type", "application/json") | ||
handler.ServeHTTP(w, r) | ||
}) | ||
}) | ||
router.Path("/authorize/callback").Handler(authorizeCallbackHandler(provider, storage, relyingParty)) | ||
return router | ||
} |
Oops, something went wrong.