diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index 1acd04764df..cb5fc4ca6c4 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -1,10 +1,10 @@ package handles import ( - "encoding/base32" "encoding/base64" "errors" "fmt" + "github.com/Xhofe/go-cache" "net/http" "net/url" "path" @@ -21,19 +21,28 @@ import ( "github.com/coreos/go-oidc" "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" - "github.com/pquerna/otp" - "github.com/pquerna/otp/totp" "golang.org/x/oauth2" "gorm.io/gorm" ) -var opts = totp.ValidateOpts{ - // state verify won't expire in 30 secs, which is quite enough for the callback - Period: 30, - Skew: 1, - // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters - Digits: otp.DigitsEight, - Algorithm: otp.AlgorithmSHA1, +const stateLength = 16 +const stateExpire = time.Minute * 5 + +var stateCache = cache.NewMemCache[string](cache.WithShards[string](stateLength)) + +func _keyState(clientID, state string) string { + return fmt.Sprintf("%s_%s", clientID, state) +} + +func generateState(clientID, ip string) string { + state := random.String(stateLength) + stateCache.Set(_keyState(clientID, state), ip, cache.WithEx[string](stateExpire)) + return state +} + +func verifyState(clientID, ip, state string) bool { + value, ok := stateCache.Get(_keyState(clientID, state)) + return ok && value == ip } func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string { @@ -91,12 +100,7 @@ func SSOLoginRedirect(c *gin.Context) { common.ErrorStrResp(c, err.Error(), 400) return } - // generate state parameter - state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) - if err != nil { - common.ErrorStrResp(c, err.Error(), 400) - return - } + state := generateState(clientId, c.ClientIP()) c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) return default: @@ -192,13 +196,7 @@ func OIDCLoginCallback(c *gin.Context) { common.ErrorResp(c, err, 400) return } - // add state verify process - stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) - if err != nil { - common.ErrorResp(c, err, 400) - return - } - if !stateVerification { + if !verifyState(clientId, c.ClientIP(), c.Query("state")) { common.ErrorStrResp(c, "incorrect or expired state parameter", 400) return }