Skip to content

Commit

Permalink
Impove addLoginToCookie function
Browse files Browse the repository at this point in the history
  • Loading branch information
anderspitman committed Sep 7, 2024
1 parent 1f2c3da commit 733bd57
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 38 deletions.
33 changes: 7 additions & 26 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ func NewOIDCHandler(db Database, config ServerConfig, tmpl *template.Template, j
prefix, err := db.GetPrefix()
checkErr(err)

loginKeyName := prefix + "login_key"

// draft-ietf-oauth-security-topics-24 2.6
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {

Expand Down Expand Up @@ -329,20 +327,6 @@ func NewOIDCHandler(db Database, config ServerConfig, tmpl *template.Template, j
return
}

loginKeyCookie, err := r.Cookie(loginKeyName)
if err != nil {
w.WriteHeader(401)
io.WriteString(w, "Only logged-in users can access this endpoint")
return
}

parsedLoginKey, err := jose.Parse(loginKeyCookie.Value)
if err != nil {
w.WriteHeader(401)
io.WriteString(w, err.Error())
return
}

clearCookie(r.Host, prefix+"auth_request", w)

parsedAuthReq, err := getJwtFromCookie(prefix+"auth_request", w, r, jose)
Expand All @@ -354,16 +338,13 @@ func NewOIDCHandler(db Database, config ServerConfig, tmpl *template.Template, j

identId := r.Form.Get("identity_id")

idents, _ := getIdentities(db, r)

var identity *Identity
tokIdentsInterface, exists := parsedLoginKey.Get("identities")
if exists {
if tokIdents, ok := tokIdentsInterface.([]*Identity); ok {
for _, ident := range tokIdents {
if ident.Id == identId {
identity = ident
break
}
}
for _, ident := range idents {
if ident.Id == identId {
identity = ident
break
}
}

Expand All @@ -383,7 +364,7 @@ func NewOIDCHandler(db Database, config ServerConfig, tmpl *template.Template, j

uri := domainToUri(r.Host)

newLoginCookie, err := addLoginToCookie(r.Host, db, loginKeyCookie.Value, clientId, newLogin, jose)
newLoginCookie, err := addLoginToCookie(db, r, clientId, newLogin)
if err != nil {
w.WriteHeader(500)
fmt.Fprintf(os.Stderr, err.Error())
Expand Down
2 changes: 1 addition & 1 deletion qr.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func NewQrHandler(db Database, cluster *Cluster, tmpl *template.Template, jose *

for clientId, clientLogins := range share.Logins {
for _, login := range clientLogins {
cookie, err = addLoginToCookie(r.Host, db, cookie.Value, clientId, login, jose)
cookie, err = addLoginToCookie(db, r, clientId, login)
if err != nil {
w.WriteHeader(500)
fmt.Fprintf(os.Stderr, err.Error())
Expand Down
31 changes: 20 additions & 11 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,21 @@ func addIdentToCookie(domain string, db Database, cookieValue string, newIdent *
return cookie, nil
}

func addLoginToCookie(domain string, db Database, currentCookieValue, clientId string, newLogin *Login, jose *JOSE) (*http.Cookie, error) {
func addLoginToCookie(db Database, r *http.Request, clientId string, newLogin *Login) (*http.Cookie, error) {

domain := r.Host

prefix, err := db.GetPrefix()
if err != nil {
return nil, err
}

loginKeyName := prefix + "login_key"

loginKeyCookie, err := r.Cookie(loginKeyName)
if err != nil {
return nil, errors.New("Only logged-in users can access this endpoint")
}

issuedAt := time.Now().UTC()

Expand All @@ -235,8 +249,10 @@ func addLoginToCookie(domain string, db Database, currentCookieValue, clientId s

keyJwt := NewJWT()

currentCookieValue := loginKeyCookie.Value

if currentCookieValue != "" {
parsed, err := jose.Parse(currentCookieValue)
parsed, err := ParseJWT(db, currentCookieValue)
if err != nil {
// Only add identities from current cookie if it's valid
} else {
Expand Down Expand Up @@ -267,7 +283,7 @@ func addLoginToCookie(domain string, db Database, currentCookieValue, clientId s
logins[clientId] = []*Login{newLogin}
}

err := keyJwt.Set("iat", issuedAt)
err = keyJwt.Set("iat", issuedAt)
if err != nil {
return nil, err
}
Expand All @@ -286,7 +302,7 @@ func addLoginToCookie(domain string, db Database, currentCookieValue, clientId s
return nil, err
}

signed, err := jose.Sign(keyJwt)
signed, err := SignJWT(db, keyJwt)
if err != nil {
return nil, err
}
Expand All @@ -298,13 +314,6 @@ func addLoginToCookie(domain string, db Database, currentCookieValue, clientId s
return nil, err
}

prefix, err := db.GetPrefix()
if err != nil {
return nil, err
}

loginKeyName := prefix + "login_key"

cookie := &http.Cookie{
Domain: cookieDomain,
Name: loginKeyName,
Expand Down

0 comments on commit 733bd57

Please sign in to comment.