diff --git a/pkg/idp/oauth/validator.go b/pkg/idp/oauth/validator.go index 42df3f6..e17bfba 100644 --- a/pkg/idp/oauth/validator.go +++ b/pkg/idp/oauth/validator.go @@ -119,14 +119,27 @@ func (b *IdentityProvider) validateAccessToken(state string, data map[string]int case "cognito": if v, exists := data["id_token"]; exists { if tp, err := kms.ParsePayloadFromToken(v.(string)); err == nil { + roles := []string{} for k, val := range tp { switch k { - case "custom:roles": - roles := []string{} - for _, roleName := range strings.Split(val.(string), "|") { - roles = append(roles, roleName) + case "custom:roles", "cognito:groups", "cognito:roles": + switch values := v.(type) { + case string: + if k == "custom:roles" { + for _, roleName := range strings.Split(val.(string), "|") { + roles = append(roles, roleName) + } + } else { + roles = append(roles, values) + } + case []interface{}: + for _, value := range values { + switch roleName := value.(type) { + case string: + roles = append(roles, roleName) + } + } } - m["roles"] = roles case "custom:timezone": m["timezone"] = val.(string) case "cognito:username": @@ -135,6 +148,9 @@ func (b *IdentityProvider) validateAccessToken(state string, data map[string]int m["timezone"] = val.(string) } } + if len(roles) > 0 { + m["roles"] = roles + } } } }