Skip to content

Commit

Permalink
Pass in scopes to narrow UserInfo returned claims
Browse files Browse the repository at this point in the history
  • Loading branch information
motoki317 committed May 8, 2024
1 parent 5aa7dd4 commit d99dd36
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
13 changes: 7 additions & 6 deletions router/oauth2/token_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,21 @@ func (h *Handler) TokenEndpointHandler(c echo.Context) error {
}

func (h *Handler) issueIDToken(client *model.OAuth2Client, token *model.OAuth2Token, userID uuid.UUID) (string, error) {
// Base claims
claims := jwt.MapClaims{
"iss": h.Origin,
"sub": userID.String(),
"aud": client.ID,
"exp": token.Deadline().Unix(),
"iat": token.CreatedAt.Unix(),
}
if token.Scopes.Contains("profile") {
userInfo, err := h.OIDC.GetUserInfo(userID)
if err != nil {
return "", err
}
claims = utils.MergeMap(userInfo, claims)
// Extra claims according to scopes (profile, email)
userInfo, err := h.OIDC.GetUserInfo(userID, token.Scopes)
if err != nil {
return "", err
}
claims = utils.MergeMap(userInfo, claims)
// Sign to JWT
return jwt2.Sign(claims)
}

Expand Down
13 changes: 12 additions & 1 deletion router/v3/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v3

import (
"context"
"github.com/samber/lo"
"net/http"
"sort"
"time"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/traPtitech/traQ/router/utils"
"github.com/traPtitech/traQ/service/channel"
"github.com/traPtitech/traQ/service/file"
"github.com/traPtitech/traQ/service/oidc"
"github.com/traPtitech/traQ/service/rbac/role"
jwt2 "github.com/traPtitech/traQ/utils/jwt"
"github.com/traPtitech/traQ/utils/optional"
Expand Down Expand Up @@ -117,9 +119,18 @@ func (h *Handlers) GetMe(c echo.Context) error {
})
}

type userAccessScopes struct{}

func (u userAccessScopes) Contains(_ model.AccessScope) bool {
return true
}

// GetMeOIDC GET /users/me/oidc
func (h *Handlers) GetMeOIDC(c echo.Context) error {
userInfo, err := h.OIDC.GetUserInfo(getRequestUserID(c))
tokenScopes, ok := c.Get(consts.KeyOAuth2AccessScopes).(model.AccessScopes)
scopes := lo.Ternary[oidc.ScopeChecker](ok, tokenScopes, userAccessScopes{})

userInfo, err := h.OIDC.GetUserInfo(getRequestUserID(c), scopes)
if err != nil {
return herror.InternalServerError(err)
}
Expand Down
39 changes: 26 additions & 13 deletions service/oidc/userinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ func NewOIDCService(
}
}

func (s *Service) GetUserInfo(userID uuid.UUID) (map[string]any, error) {
type ScopeChecker interface {
Contains(scope model.AccessScope) bool
}

func (s *Service) GetUserInfo(userID uuid.UUID, scopes ScopeChecker) (map[string]any, error) {
user, err := s.repo.GetUser(userID, true)
if err != nil {
return nil, err
Expand All @@ -41,16 +45,19 @@ func (s *Service) GetUserInfo(userID uuid.UUID) (map[string]any, error) {
return nil, err
}

return map[string]any{
// OIDC standard claims
"name": user.GetName(),
"email": user.GetName() + "+dummy@example.com",
"email_verified": false,
"preferred_username": user.GetName(),
"picture": s.origin + "/api/v3/public/icon/" + user.GetName(),
"updated_at": user.GetUpdatedAt(),
// traQ specific claims
"traq": map[string]any{
// Build claims
claims := make(map[string]any)

// Required in UserInfo response
claims["sub"] = userID.String()

// Scope specific claims
if scopes.Contains("profile") {
claims["name"] = user.GetName()
claims["preferred_username"] = user.GetName()
claims["picture"] = s.origin + "/api/v3/public/icon/" + user.GetName()
claims["updated_at"] = user.GetUpdatedAt().Unix()
claims["traq"] = map[string]any{
"bio": user.GetBio(),
"groups": groups,
"tags": utils.Map(tags, func(tag model.UserTag) string { return tag.GetTag() }),
Expand All @@ -62,6 +69,12 @@ func (s *Service) GetUserInfo(userID uuid.UUID) (map[string]any, error) {
"state": user.GetState().Int(),
"permissions": s.rbac.GetGrantedPermissions(user.GetRole()),
"home_channel": user.GetHomeChannel(),
},
}, nil
}
}
if scopes.Contains("email") {
claims["email"] = user.GetName() + "+dummy@example.com"
claims["email_verified"] = false
}

return claims, nil
}

0 comments on commit d99dd36

Please sign in to comment.