diff --git a/internal/config/config.go b/internal/config/config.go index 001ae2a..440a028 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,18 +18,22 @@ type TLS struct { Keys Keys `yaml:"keys"` } +type Cookies struct { + AuthName string `yaml:"authName"` + MessageName string `yaml:"messageName"` +} + type Server struct { - LogLevel string `yaml:"logLevel"` - Addr string `yaml:"addr"` - AuthCookieName string `yaml:"authCookieName"` - MessageCookieName string `yaml:"messageCookieName"` - Secret string `yaml:"secret"` - PrivateKey string `yaml:"privateKey"` - TLS TLS `yaml:"tls"` - LogoutRedirect string `yaml:"logoutRedirect"` - IntrospectScope string `yaml:"introspectScope"` - RevokeScope string `yaml:"revokeScopeScope"` - SessionTimeoutSeconds int `yaml:"sessionTimeoutSeconds"` + LogLevel string `yaml:"logLevel"` + Addr string `yaml:"addr"` + Cookies Cookies `yaml:"cookies"` + Secret string `yaml:"secret"` + PrivateKey string `yaml:"privateKey"` + TLS TLS `yaml:"tls"` + LogoutRedirect string `yaml:"logoutRedirect"` + IntrospectScope string `yaml:"introspectScope"` + RevokeScope string `yaml:"revokeScopeScope"` + SessionTimeoutSeconds int `yaml:"sessionTimeoutSeconds"` } type UserAddress struct { @@ -246,11 +250,11 @@ func (config *Config) GetClient(name string) (*Client, bool) { } func (config *Config) GetAuthCookieName() string { - return GetOrDefaultString(config.Server.AuthCookieName, "stopnik_auth") + return GetOrDefaultString(config.Server.Cookies.AuthName, "stopnik_auth") } func (config *Config) GetMessageCookieName() string { - return GetOrDefaultString(config.Server.MessageCookieName, "stopnik_message") + return GetOrDefaultString(config.Server.Cookies.MessageName, "stopnik_message") } func (config *Config) GetSessionTimeoutSeconds() int { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c63977f..a73af30 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -143,8 +143,10 @@ func simpleServerConfiguration(t *testing.T) { origin := out.(*Config) *origin = Config{ Server: Server{ - Secret: "5XyLSgKpo5kWrJqm", - AuthCookieName: "my_auth", + Secret: "5XyLSgKpo5kWrJqm", + Cookies: Cookies{ + AuthName: "my_auth", + }, IntrospectScope: "i:a", RevokeScope: "r:b", SessionTimeoutSeconds: 4200, diff --git a/internal/crypto/key.go b/internal/crypto/key.go index 18db41f..afd25a9 100644 --- a/internal/crypto/key.go +++ b/internal/crypto/key.go @@ -6,6 +6,9 @@ import ( "encoding/pem" "errors" "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/webishdev/stopnik/internal/config" "os" ) @@ -14,6 +17,34 @@ type SigningPrivateKey struct { SignatureAlgorithm jwa.SignatureAlgorithm } +type ManagedKey struct { + Id string + Clients []*config.Client + Server bool + Key *jwk.Key +} + +type ServerSecretLoader interface { + GetServerSecret() jwt.SignEncryptParseOption +} + +type serverSecret struct { + secret string +} + +type KeyLoader interface { + LoadKeys(client *config.Client) (*ManagedKey, bool) + ServerSecretLoader +} + +func NewServerSecretLoader(config *config.Config) ServerSecretLoader { + return &serverSecret{secret: config.GetServerSecret()} +} + +func (s *serverSecret) GetServerSecret() jwt.SignEncryptParseOption { + return jwt.WithKey(jwa.HS256, []byte(s.secret)) +} + func LoadPrivateKey(name string) (*SigningPrivateKey, error) { privateKeyBytes, readError := os.ReadFile(name) if readError != nil { diff --git a/internal/http/cookie.go b/internal/http/cookie.go index 4512fcf..d849712 100644 --- a/internal/http/cookie.go +++ b/internal/http/cookie.go @@ -1,9 +1,9 @@ package http import ( - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/webishdev/stopnik/internal/config" + "github.com/webishdev/stopnik/internal/crypto" "github.com/webishdev/stopnik/log" "net/http" "time" @@ -12,8 +12,9 @@ import ( type Now func() time.Time type CookieManager struct { - config *config.Config - now Now + config *config.Config + keyFallback crypto.ServerSecretLoader + now Now } func NewCookieManager(config *config.Config) *CookieManager { @@ -21,7 +22,11 @@ func NewCookieManager(config *config.Config) *CookieManager { } func newCookieManagerWithTime(config *config.Config, now Now) *CookieManager { - return &CookieManager{config: config, now: now} + return &CookieManager{ + config: config, + keyFallback: crypto.NewServerSecretLoader(config), + now: now, + } } func (cookieManager *CookieManager) CreateMessageCookie(message string) http.Cookie { @@ -87,7 +92,8 @@ func (cookieManager *CookieManager) ValidateAuthCookie(r *http.Request) (*config } func (cookieManager *CookieManager) validateCookieValue(cookie *http.Cookie) (*config.User, bool) { - token, err := jwt.Parse([]byte(cookie.Value), jwt.WithKey(jwa.HS256, []byte(cookieManager.config.GetServerSecret()))) + options := cookieManager.keyFallback.GetServerSecret() + token, err := jwt.Parse([]byte(cookie.Value), options) if err != nil { return &config.User{}, false } @@ -108,7 +114,8 @@ func (cookieManager *CookieManager) generateCookieValue(username string) (string return "", builderError } - tokenString, tokenError := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte(cookieManager.config.GetServerSecret()))) + options := cookieManager.keyFallback.GetServerSecret() + tokenString, tokenError := jwt.Sign(token, options) if tokenError != nil { return "", tokenError } diff --git a/internal/server/handler/authorize/authorize_test.go b/internal/server/handler/authorize/authorize_test.go index 413f8d1..1ea5f37 100644 --- a/internal/server/handler/authorize/authorize_test.go +++ b/internal/server/handler/authorize/authorize_test.go @@ -278,7 +278,7 @@ func testAuthorizeValidLoginNoSession(t *testing.T, testConfig *config.Config, k requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) cookieManager := internalHttp.NewCookieManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) @@ -389,7 +389,7 @@ func testAuthorizeValidLoginAuthorizationGrant(t *testing.T, testConfig *config. requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) cookieManager := internalHttp.NewCookieManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) @@ -477,7 +477,7 @@ func testAuthorizeValidLoginImplicitGrant(t *testing.T, testConfig *config.Confi requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) cookieManager := internalHttp.NewCookieManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) authorizeHandler := NewAuthorizeHandler(requestValidator, cookieManager, sessionManager, tokenManager, &template.Manager{}) @@ -586,7 +586,7 @@ func testAuthorizeImplicitGrant(t *testing.T, testConfig *config.Config, keyMang requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) cookieManager := internalHttp.NewCookieManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) client, _ := testConfig.GetClient("foo") user, _ := testConfig.GetUser("foo") diff --git a/internal/server/handler/health/health_test.go b/internal/server/handler/health/health_test.go index 9e0de28..bdbe958 100644 --- a/internal/server/handler/health/health_test.go +++ b/internal/server/handler/health/health_test.go @@ -39,7 +39,7 @@ func Test_Health(t *testing.T) { } t.Run("Health without token", func(t *testing.T) { - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) healthHandler := NewHealthHandler(tokenManager) @@ -65,7 +65,7 @@ func Test_Health(t *testing.T) { }) t.Run("Health with token", func(t *testing.T) { - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) client, clientExists := testConfig.GetClient("foo") if !clientExists { diff --git a/internal/server/handler/introspect/introspect_test.go b/internal/server/handler/introspect/introspect_test.go index 5199489..6d9fd54 100644 --- a/internal/server/handler/introspect/introspect_test.go +++ b/internal/server/handler/introspect/introspect_test.go @@ -75,7 +75,7 @@ func Test_Introspect(t *testing.T) { func testIntrospectMissingClientCredentials(t *testing.T, testConfig *config.Config, keyManger *store.KeyManger) { t.Run("Missing client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -92,7 +92,7 @@ func testIntrospectMissingClientCredentials(t *testing.T, testConfig *config.Con func testIntrospectInvalidClientCredentials(t *testing.T, testConfig *config.Config, keyManger *store.KeyManger) { t.Run("Invalid client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -123,7 +123,7 @@ func testIntrospectEmptyToken(t *testing.T, testConfig *config.Config, keyManger testMessage := fmt.Sprintf("Introspect empty %v", test.tokenHint) t.Run(testMessage, func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -169,7 +169,7 @@ func testIntrospectInvalidToken(t *testing.T, testConfig *config.Config, keyMang testMessage := fmt.Sprintf("Introspect invalid %v", test.tokenHint) t.Run(testMessage, func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) introspectHandler := NewIntrospectHandler(testConfig, requestValidator, tokenManager) @@ -234,7 +234,7 @@ func testIntrospect(t *testing.T, testConfig *config.Config, keyManger *store.Ke requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) @@ -306,7 +306,7 @@ func testIntrospectWithoutHint(t *testing.T, testConfig *config.Config, keyMange requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) @@ -377,7 +377,7 @@ func testIntrospectDisabled(t *testing.T, testConfig *config.Config, keyManger * requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManger)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) diff --git a/internal/server/handler/keys/keys.go b/internal/server/handler/keys/keys.go index a53dc7d..d7df335 100644 --- a/internal/server/handler/keys/keys.go +++ b/internal/server/handler/keys/keys.go @@ -3,6 +3,7 @@ package keys import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/webishdev/stopnik/internal/config" + "github.com/webishdev/stopnik/internal/crypto" http2 "github.com/webishdev/stopnik/internal/http" errorHandler "github.com/webishdev/stopnik/internal/server/handler/error" "github.com/webishdev/stopnik/internal/store" @@ -59,7 +60,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (h *Handler) addKey(mangedKey *store.ManagedKey) error { +func (h *Handler) addKey(mangedKey *crypto.ManagedKey) error { key := *mangedKey.Key addKeyError := h.keySet.AddKey(key) diff --git a/internal/server/handler/metadata/metadata.go b/internal/server/handler/metadata/metadata.go index 219ddf3..fa8572b 100644 --- a/internal/server/handler/metadata/metadata.go +++ b/internal/server/handler/metadata/metadata.go @@ -85,6 +85,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { jwa.ES256, jwa.ES384, jwa.ES512, + jwa.HS256, } metadataResponse := &response{ diff --git a/internal/server/handler/revoke/revoke_test.go b/internal/server/handler/revoke/revoke_test.go index d8fe511..2f35fb1 100644 --- a/internal/server/handler/revoke/revoke_test.go +++ b/internal/server/handler/revoke/revoke_test.go @@ -73,7 +73,7 @@ func Test_Revoke(t *testing.T) { func testRevokeMissingClientCredentials(t *testing.T, testConfig *config.Config, keyManager *store.KeyManger) { t.Run("Missing client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -90,7 +90,7 @@ func testRevokeMissingClientCredentials(t *testing.T, testConfig *config.Config, func testRevokeInvalidClientCredentials(t *testing.T, testConfig *config.Config, keyManager *store.KeyManger) { t.Run("Invalid client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -121,7 +121,7 @@ func testRevokeEmptyToken(t *testing.T, testConfig *config.Config, keyManager *s testMessage := fmt.Sprintf("Revoke empty %v", test.tokenHint) t.Run(testMessage, func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -160,7 +160,7 @@ func testRevokeInvalidToken(t *testing.T, testConfig *config.Config, keyManager testMessage := fmt.Sprintf("Revoke invalid %v", test.tokenHint) t.Run(testMessage, func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) revokeHandler := NewRevokeHandler(testConfig, requestValidator, tokenManager) @@ -217,7 +217,7 @@ func testRevoke(t *testing.T, testConfig *config.Config, keyManager *store.KeyMa requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) @@ -294,7 +294,7 @@ func testRevokeWithoutHint(t *testing.T, testConfig *config.Config, keyManager * requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) @@ -370,7 +370,7 @@ func testRevokeDisabled(t *testing.T, testConfig *config.Config, keyManager *sto requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) diff --git a/internal/server/handler/token/token_test.go b/internal/server/handler/token/token_test.go index ed8a24c..6ec16e0 100644 --- a/internal/server/handler/token/token_test.go +++ b/internal/server/handler/token/token_test.go @@ -75,7 +75,7 @@ func testTokenMissingClientCredentials(t *testing.T, testConfig *config.Config, t.Run("Missing client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -93,7 +93,7 @@ func testTokenInvalidClientCredentials(t *testing.T, testConfig *config.Config, t.Run("Invalid client credentials", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -114,7 +114,7 @@ func testTokenMissingGrandType(t *testing.T, testConfig *config.Config, keyManag t.Run("Missing grant type", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -135,7 +135,7 @@ func testTokenInvalidGrandType(t *testing.T, testConfig *config.Config, keyManag t.Run("Invalid grant type", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -162,7 +162,7 @@ func testTokenAuthorizationCodeGrantTypeMissingCodeParameter(t *testing.T, testC t.Run("Authorization code grant type, missing code parameter", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -203,7 +203,7 @@ func testTokenAuthorizationCodeGrantTypeInvalidPKCE(t *testing.T, testConfig *co requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -275,7 +275,7 @@ func testTokenAuthorizationCodeGrantType(t *testing.T, testConfig *config.Config requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -316,7 +316,7 @@ func testTokenPasswordGrantType(t *testing.T, testConfig *config.Config, keyMana t.Run("Password grant type", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -350,7 +350,7 @@ func testTokenClientCredentialsGrantType(t *testing.T, testConfig *config.Config t.Run("Client credentials grant type", func(t *testing.T) { requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) tokenHandler := NewTokenHandler(requestValidator, sessionManager, tokenManager) @@ -399,7 +399,7 @@ func testTokenRefreshTokenGrantType(t *testing.T, testConfig *config.Config, key requestValidator := validation.NewRequestValidator(testConfig) sessionManager := store.NewSessionManager(testConfig) - tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(keyManager)) + tokenManager := store.NewTokenManager(testConfig, store.NewDefaultKeyLoader(testConfig, keyManager)) sessionManager.StartSession(authSession) accessTokenResponse := tokenManager.CreateAccessTokenResponse(user.Username, client, scopes) diff --git a/internal/server/server.go b/internal/server/server.go index 5a3773b..abc5b12 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -178,7 +178,7 @@ func registerHandlers(config *config.Config, handle func(pattern string, handler os.Exit(1) } sessionManager := store.NewSessionManager(config) - tokenManager := store.NewTokenManager(config, store.NewDefaultKeyLoader(keyManger)) + tokenManager := store.NewTokenManager(config, store.NewDefaultKeyLoader(config, keyManger)) cookieManager := internalHttp.NewCookieManager(config) requestValidator := validation.NewRequestValidator(config) templateManager := template.NewTemplateManager(config) diff --git a/internal/store/keys.go b/internal/store/keys.go index 5237e87..ce3a167 100644 --- a/internal/store/keys.go +++ b/internal/store/keys.go @@ -7,23 +7,42 @@ import ( "encoding/base64" "fmt" "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/webishdev/stopnik/internal/config" "github.com/webishdev/stopnik/internal/crypto" ) -type ManagedKey struct { - Id string - Clients []*config.Client - Server bool - Key *jwk.Key +type KeyManger struct { + keyStore *Store[crypto.ManagedKey] } -type KeyManger struct { - keyStore *Store[ManagedKey] +type DefaultKeyLoader struct { + keyFallback crypto.ServerSecretLoader + keyManager *KeyManger +} + +func NewDefaultKeyLoader(config *config.Config, keyManager *KeyManger) *DefaultKeyLoader { + return &DefaultKeyLoader{ + keyFallback: crypto.NewServerSecretLoader(config), + keyManager: keyManager, + } +} + +func (defaultKeyLoader *DefaultKeyLoader) LoadKeys(client *config.Client) (*crypto.ManagedKey, bool) { + key := defaultKeyLoader.keyManager.getClientKey(client) + if key == nil { + return nil, false + } + + return key, true +} + +func (defaultKeyLoader *DefaultKeyLoader) GetServerSecret() jwt.SignEncryptParseOption { + return defaultKeyLoader.keyFallback.GetServerSecret() } func NewKeyManger(config *config.Config) (*KeyManger, error) { - newStore := NewStore[ManagedKey]() + newStore := NewStore[crypto.ManagedKey]() keyManager := &KeyManger{ keyStore: &newStore, } @@ -41,8 +60,8 @@ func NewKeyManger(config *config.Config) (*KeyManger, error) { return keyManager, nil } -func (km *KeyManger) getClientKey(c *config.Client) *ManagedKey { - var result *ManagedKey +func (km *KeyManger) getClientKey(c *config.Client) *crypto.ManagedKey { + var result *crypto.ManagedKey for _, mangedKey := range km.GetAllKeys() { if mangedKey.Server { result = mangedKey @@ -58,7 +77,7 @@ func (km *KeyManger) getClientKey(c *config.Client) *ManagedKey { return result } -func (km *KeyManger) GetAllKeys() []*ManagedKey { +func (km *KeyManger) GetAllKeys() []*crypto.ManagedKey { keyStore := *km.keyStore return keyStore.GetValues() } @@ -103,11 +122,11 @@ func (km *KeyManger) addClientKeys(c *config.Config) error { return nil } -func (km *KeyManger) addManagedKey(managedKey *ManagedKey) { +func (km *KeyManger) addManagedKey(managedKey *crypto.ManagedKey) { keyStore := *km.keyStore existingKey, exists := keyStore.Get(managedKey.Id) if exists { - mergedKey := &ManagedKey{ + mergedKey := &crypto.ManagedKey{ Id: managedKey.Id, Key: managedKey.Key, Server: managedKey.Server || existingKey.Server, @@ -119,7 +138,7 @@ func (km *KeyManger) addManagedKey(managedKey *ManagedKey) { } } -func (km *KeyManger) convert(signingPrivateKey *crypto.SigningPrivateKey) (*ManagedKey, error) { +func (km *KeyManger) convert(signingPrivateKey *crypto.SigningPrivateKey) (*crypto.ManagedKey, error) { keyAsBytes, loadError := km.getBytes(signingPrivateKey.PrivateKey) if loadError != nil { return nil, loadError @@ -146,7 +165,7 @@ func (km *KeyManger) convert(signingPrivateKey *crypto.SigningPrivateKey) (*Mana return nil, setError } - managedKey := &ManagedKey{ + managedKey := &crypto.ManagedKey{ Id: kid, Key: &key, Clients: []*config.Client{}, diff --git a/internal/store/token.go b/internal/store/token.go index d13af45..a557095 100644 --- a/internal/store/token.go +++ b/internal/store/token.go @@ -4,9 +4,9 @@ import ( "encoding/base64" "fmt" "github.com/google/uuid" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/webishdev/stopnik/internal/config" + "github.com/webishdev/stopnik/internal/crypto" internalHttp "github.com/webishdev/stopnik/internal/http" "github.com/webishdev/stopnik/internal/oauth2" "github.com/webishdev/stopnik/log" @@ -16,26 +16,12 @@ import ( type TokenManager struct { config *config.Config - keyLoader KeyLoader + keyLoader crypto.KeyLoader accessTokenStore *ExpiringStore[oauth2.AccessToken] refreshTokenStore *ExpiringStore[oauth2.RefreshToken] } -type KeyLoader interface { - LoadKeys(client *config.Client) (*ManagedKey, bool) -} - -type DefaultKeyLoader struct { - keyManager *KeyManger -} - -func NewDefaultKeyLoader(keyManager *KeyManger) *DefaultKeyLoader { - return &DefaultKeyLoader{ - keyManager: keyManager, - } -} - -func NewTokenManager(config *config.Config, keyLoader KeyLoader) *TokenManager { +func NewTokenManager(config *config.Config, keyLoader crypto.KeyLoader) *TokenManager { accessTokenStore := NewDefaultTimedStore[oauth2.AccessToken]() refreshTokenStore := NewDefaultTimedStore[oauth2.RefreshToken]() return &TokenManager{ @@ -189,8 +175,8 @@ func (tokenManager *TokenManager) generateJWTToken(tokenId string, duration time managedKey, keyExists := loader.LoadKeys(client) if !keyExists { - signKey := jwt.WithKey(jwa.HS256, []byte(tokenManager.config.GetServerSecret())) - tokenString, tokenError := jwt.Sign(token, signKey) + options := loader.GetServerSecret() + tokenString, tokenError := jwt.Sign(token, options) if tokenError != nil { panic(tokenError) } @@ -211,15 +197,6 @@ func (tokenManager *TokenManager) generateJWTToken(tokenId string, duration time } -func (defaultKeyLoader *DefaultKeyLoader) LoadKeys(client *config.Client) (*ManagedKey, bool) { - key := defaultKeyLoader.keyManager.getClientKey(client) - if key == nil { - return nil, false - } - - return key, true -} - func getAuthorizationHeaderValue(authorizationHeader string) *string { if authorizationHeader == "" || !strings.Contains(authorizationHeader, internalHttp.AuthBearer) { return nil diff --git a/internal/store/token_test.go b/internal/store/token_test.go index 5af8a0d..74bb687 100644 --- a/internal/store/token_test.go +++ b/internal/store/token_test.go @@ -27,7 +27,7 @@ func Test_Token(t *testing.T) { t.Run(testMessage, func(t *testing.T) { testConfig := createTestConfig(t, test.opaque, test.refreshTokenTTL) keyManager := createTestKeyManager(t, testConfig) - tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(keyManager)) + tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(testConfig, keyManager)) client, clientExists := testConfig.GetClient("foo") if !clientExists { t.Fatal("client does not exist") @@ -107,7 +107,7 @@ func Test_Token(t *testing.T) { t.Run("Invalid HTTP Authorization header", func(t *testing.T) { testConfig := createTestConfig(t, false, 0) keyManager := createTestKeyManager(t, testConfig) - tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(keyManager)) + tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(testConfig, keyManager)) _, _, valid := tokenManager.ValidateAccessToken("foooo") @@ -119,7 +119,7 @@ func Test_Token(t *testing.T) { t.Run("Invalid Token value", func(t *testing.T) { testConfig := createTestConfig(t, false, 0) keyManager := createTestKeyManager(t, testConfig) - tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(keyManager)) + tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(testConfig, keyManager)) _, _, valid := tokenManager.ValidateAccessToken(fmt.Sprintf("%s %s", internalHttp.AuthBearer, "foo")) @@ -131,7 +131,7 @@ func Test_Token(t *testing.T) { t.Run("Invalid User in token", func(t *testing.T) { testConfig := createTestConfig(t, false, 0) keyManager := createTestKeyManager(t, testConfig) - tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(keyManager)) + tokenManager := NewTokenManager(testConfig, NewDefaultKeyLoader(testConfig, keyManager)) client, clientExists := testConfig.GetClient("foo") if !clientExists { t.Fatal("client does not exist")