Skip to content

Commit

Permalink
fix: excluded code
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag committed Sep 20, 2022
1 parent 879b76a commit 346fa1e
Show file tree
Hide file tree
Showing 9 changed files with 949 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
vendor
.idea
pkg/oidc
57 changes: 57 additions & 0 deletions pkg/oidc/authorize_callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package oidc

import (
"context"
"net/http"

auth "github.com/formancehq/auth/pkg"
"github.com/formancehq/auth/pkg/delegatedauth"
"github.com/google/uuid"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/op"
)

func authorizeCallbackHandler(
provider op.OpenIDProvider,
storage Storage,
relyingParty rp.RelyingParty,
) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

state, err := delegatedauth.DecodeDelegatedState(r.URL.Query().Get("state"))
if err != nil {
panic(err)
}

authRequest, err := storage.FindAuthRequest(context.Background(), state.AuthRequestID)
if err != nil {
panic(err)
}

tokens, err := rp.CodeExchange(context.Background(), r.URL.Query().Get("code"), relyingParty)
if err != nil {
panic(err)
}

user, err := storage.FindUserBySubject(r.Context(), tokens.IDTokenClaims.GetSubject())
if err != nil {
user = &auth.User{
ID: uuid.NewString(),
Subject: tokens.IDTokenClaims.GetSubject(),
Email: tokens.IDTokenClaims.GetEmail(),
}
if err := storage.SaveUser(r.Context(), *user); err != nil {
panic(err)
}
}

authRequest.UserID = user.ID

if err := storage.UpdateAuthRequest(r.Context(), *authRequest); err != nil {
panic(err)
}

w.Header().Set("Location", op.AuthCallbackURL(provider)(state.AuthRequestID))
w.WriteHeader(http.StatusFound)
}
}
120 changes: 120 additions & 0 deletions pkg/oidc/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package oidc

import (
"time"

auth "github.com/formancehq/auth/pkg"
"github.com/formancehq/auth/pkg/delegatedauth"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
)

type clientFacade struct {
Client auth.Client
relyingParty rp.RelyingParty
}

func NewClientFacade(client auth.Client, relyingParty rp.RelyingParty) *clientFacade {
return &clientFacade{
Client: client,
relyingParty: relyingParty,
}
}

// GetID must return the client_id
func (c *clientFacade) GetID() string {
return c.Client.Id
}

// RedirectURIs must return the registered redirect_uris for Code and Implicit Flow
func (c *clientFacade) RedirectURIs() []string {
return c.Client.RedirectURIs
}

// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs
func (c *clientFacade) PostLogoutRedirectURIs() []string {
return c.Client.PostLogoutRedirectUris
}

// ApplicationType must return the type of the client (app, native, user agent)
func (c *clientFacade) ApplicationType() op.ApplicationType {
return c.Client.ApplicationType
}

// AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt)
func (c *clientFacade) AuthMethod() oidc.AuthMethod {
return c.Client.AuthMethod
}

// ResponseTypes must return all allowed response types (code, id_token token, id_token)
// these must match with the allowed grant types
func (c *clientFacade) ResponseTypes() []oidc.ResponseType {
return c.Client.ResponseTypes
}

// GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer)
func (c *clientFacade) GrantTypes() []oidc.GrantType {
return c.Client.GrantTypes
}

// LoginURL will be called to redirect the user (agent) to the login UI
// you could implement some logic here to redirect the users to different login UIs depending on the client
func (c *clientFacade) LoginURL(id string) string {
return rp.AuthURL(delegatedauth.DelegatedState{
AuthRequestID: id,
}.EncodeAsUrlParam(), c.relyingParty)
}

// AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT)
func (c *clientFacade) AccessTokenType() op.AccessTokenType {
return c.Client.AccessTokenType
}

// IDTokenLifetime must return the lifetime of the client's id_tokens
func (c *clientFacade) IDTokenLifetime() time.Duration {
return 1 * time.Hour
}

// DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client)
func (c *clientFacade) DevMode() bool {
return c.Client.DevMode
}

// RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token
func (c *clientFacade) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
return func(scopes []string) []string {
return scopes
}
}

// RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token
func (c *clientFacade) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
return func(scopes []string) []string {
return scopes
}
}

// IsScopeAllowed enables Client specific custom scopes validation
func (c *clientFacade) IsScopeAllowed(label string) bool {
for _, scope := range c.Client.Scopes {
if scope.Label == label {
return true
}
}
return false
}

// IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token
// even if an access token if issued which violates the OIDC Core spec
// (5.4. Requesting Claims using Scope Values: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims)
// some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued
func (c *clientFacade) IDTokenUserinfoClaimsAssertion() bool {
return c.Client.IdTokenUserinfoClaimsAssertion
}

// ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations
// (subtract from issued_at, add to expiration, ...)
func (c *clientFacade) ClockSkew() time.Duration {
return c.Client.ClockSkew
}
34 changes: 34 additions & 0 deletions pkg/oidc/module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package oidc

import (
"context"

auth "github.com/formancehq/auth/pkg"
"github.com/gorilla/mux"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/op"
"go.uber.org/fx"
)

func Module(addr, issuer string) fx.Option {
return fx.Options(
fx.Provide(NewRouter),
fx.Provide(fx.Annotate(func(storage Storage, relyingParty rp.RelyingParty, opts []auth.ClientOptions) *storageFacade {
var staticClients []auth.Client
for _, c := range opts {
staticClients = append(staticClients, *auth.NewClient(c))
}
return NewStorageFacade(storage, relyingParty, staticClients...)
}, fx.As(new(op.Storage)))),
fx.Provide(func(storage op.Storage) (op.OpenIDProvider, error) {
return NewOpenIDProvider(context.TODO(), storage, issuer)
}),
fx.Invoke(func(lc fx.Lifecycle, router *mux.Router) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
return StartServer(addr, router)
},
})
}),
)
}
119 changes: 119 additions & 0 deletions pkg/oidc/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package oidc_test

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"

auth "github.com/formancehq/auth/pkg"
"github.com/formancehq/auth/pkg/oidc"
"github.com/formancehq/auth/pkg/storage/sqlstorage"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/op"
"gorm.io/driver/sqlite"
)

func init() {
os.Setenv(op.OidcDevMode, "true")
}

func withServer(t *testing.T, fn func(storage *sqlstorage.Storage, provider op.OpenIDProvider)) {
// Create a mock OIDC server which will always return a default user
mockOIDC, err := mockoidc.Run()
require.NoError(t, err)
defer func() {
require.NoError(t, mockOIDC.Shutdown())
}()

// Prepare a tcp connection, listening on :0 to select a random port
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

// Compute server url, it will be the "issuer" of our oidc provider
serverUrl := fmt.Sprintf("http://%s", l.Addr().String())

// As our oidc provider, is also a relying party (it delegates authentication), we need to construct a relying party
// with information from the mock
serverRelyingParty, err := rp.NewRelyingPartyOIDC(mockOIDC.Issuer(), mockOIDC.ClientID, mockOIDC.ClientSecret,
fmt.Sprintf("%s/authorize/callback", serverUrl), []string{"openid", "email"})
require.NoError(t, err)

// Construct our storage
db, err := sqlstorage.LoadGorm(sqlite.Open(":memory:"))
require.NoError(t, err)
require.NoError(t, sqlstorage.MigrateTables(context.Background(), db))

storage := sqlstorage.New(db)
storageFacade := oidc.NewStorageFacade(storage, serverRelyingParty)

// Construct our oidc provider
provider, err := oidc.NewOpenIDProvider(context.TODO(), storageFacade, serverUrl)
require.NoError(t, err)

// Create the router
router := oidc.NewRouter(provider, storage, serverRelyingParty)

// Create our http server for our oidc provider
providerHttpServer := &http.Server{
Handler: router,
}
go func() {
err := providerHttpServer.Serve(l)
if err != http.ErrServerClosed {
require.Fail(t, err.Error())
}
}()
defer providerHttpServer.Close()

fn(storage, provider)
}

func Test3LeggedFlow(t *testing.T) {

withServer(t, func(storage *sqlstorage.Storage, provider op.OpenIDProvider) {
// Create ou http server for our client (a web application for example)
code := make(chan string, 1) // Just store codes coming from our provider inside a chan
clientHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
code <- r.URL.Query().Get("code")
})
clientHttpServer := httptest.NewServer(clientHandler)
defer clientHttpServer.Close()

// Create a OAuth2 client which represent our client application
client := auth.NewClient(auth.ClientOptions{})
client.RedirectURIs.Append(clientHttpServer.URL) // Need to configure the redirect uri
_, clear := client.GenerateNewSecret(auth.SecretCreate{}) // Need to generate a secret
require.NoError(t, storage.SaveClient(context.TODO(), *client))

// As our client is a relying party, we can use the library to get some helpers
clientRelyingParty, err := rp.NewRelyingPartyOIDC(provider.Issuer(), client.Id, clear, client.RedirectURIs[0], []string{"openid", "email"})
require.NoError(t, err)

// Trigger an authentication request
authUrl := rp.AuthURL("", clientRelyingParty)
rsp, err := (&http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
fmt.Println(req.URL.String())
return nil
},
}).Get(authUrl)
require.NoError(t, err)
require.Equal(t, http.StatusOK, rsp.StatusCode)

select {
// As the mock automatically accept login response, we should have received a code
case code := <-code:
// And this code is used to get a token
_, err := rp.CodeExchange(context.TODO(), code, clientRelyingParty)
require.NoError(t, err)
default:
require.Fail(t, "code was expected")
}
})
}
27 changes: 27 additions & 0 deletions pkg/oidc/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package oidc

import (
"context"
"crypto/sha256"

"github.com/zitadel/oidc/pkg/op"
"golang.org/x/text/language"
)

const (
pathLoggedOut = "/logged-out"
)

func NewOpenIDProvider(ctx context.Context, storage op.Storage, issuer string) (op.OpenIDProvider, error) {
return op.NewOpenIDProvider(ctx, &op.Config{
Issuer: issuer,
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
}, storage)
}
24 changes: 24 additions & 0 deletions pkg/oidc/router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package oidc

import (
"net/http"

"github.com/gorilla/mux"
"github.com/zitadel/oidc/pkg/client/rp"
"github.com/zitadel/oidc/pkg/op"
"go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux"
)

func NewRouter(provider op.OpenIDProvider, storage Storage,
relyingParty rp.RelyingParty) *mux.Router {
router := provider.HttpHandler().(*mux.Router)
router.Use(otelmux.Middleware("auth"))
router.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
handler.ServeHTTP(w, r)
})
})
router.Path("/authorize/callback").Handler(authorizeCallbackHandler(provider, storage, relyingParty))
return router
}
Loading

0 comments on commit 346fa1e

Please sign in to comment.