Skip to content

Commit

Permalink
feat: configurable NameID format for SAML provider (supabase#1481)
Browse files Browse the repository at this point in the history
Adds the ability to specify the [NameID
format](https://pkg.go.dev/github.com/crewjam/saml#NameIDFormat) for a
SAML SSO provider's authentication request, as some providers don't
appear to support accepting the persistent format.
  • Loading branch information
hf authored Mar 14, 2024
1 parent c38d646 commit fc51dff
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
8 changes: 6 additions & 2 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)
}

// TODO: fetch new metadata if validUntil < time.Now()

serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */)

authnRequest, err := serviceProvider.MakeAuthenticationRequest(
Expand All @@ -104,6 +102,12 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
return internalServerError("Error creating SAML Authentication Request").WithInternalError(err)
}

// Some IdPs do not support the use of the `persistent` NameID format,
// and require a different format to be sent to work.
if ssoProvider.SAMLProvider.NameIDFormat != nil {
authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat
}

relayState := models.SAMLRelayState{
SSOProviderID: ssoProvider.ID,
RequestID: authnRequest.ID,
Expand Down
39 changes: 37 additions & 2 deletions internal/api/ssoadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/url"
"strings"
"unicode/utf8"

"github.com/crewjam/saml"
Expand Down Expand Up @@ -74,6 +75,7 @@ type CreateSSOProviderParams struct {
MetadataXML string `json:"metadata_xml"`
Domains []string `json:"domains"`
AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"`
NameIDFormat string `json:"name_id_format"`
}

func (p *CreateSSOProviderParams) validate(forUpdate bool) error {
Expand All @@ -94,8 +96,22 @@ func (p *CreateSSOProviderParams) validate(forUpdate bool) error {
}
}

// TODO validate p.AttributeMapping
// TODO validate domains
switch p.NameIDFormat {
case "",
string(saml.PersistentNameIDFormat),
string(saml.EmailAddressNameIDFormat),
string(saml.TransientNameIDFormat),
string(saml.UnspecifiedNameIDFormat):
// it's valid

default:
return badRequestError(ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{
string(saml.PersistentNameIDFormat),
string(saml.EmailAddressNameIDFormat),
string(saml.TransientNameIDFormat),
string(saml.UnspecifiedNameIDFormat),
}, ", "))
}

return nil
}
Expand Down Expand Up @@ -217,6 +233,10 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
provider.SAMLProvider.MetadataURL = &params.MetadataURL
}

if params.NameIDFormat != "" {
provider.SAMLProvider.NameIDFormat = &params.NameIDFormat
}

provider.SAMLProvider.AttributeMapping = params.AttributeMapping

for _, domain := range params.Domains {
Expand Down Expand Up @@ -335,6 +355,21 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
provider.SAMLProvider.AttributeMapping = params.AttributeMapping
}

nameIDFormat := ""
if provider.SAMLProvider.NameIDFormat != nil {
nameIDFormat = *provider.SAMLProvider.NameIDFormat
}

if params.NameIDFormat != nameIDFormat {
modified = true

if params.NameIDFormat == "" {
provider.SAMLProvider.NameIDFormat = nil
} else {
provider.SAMLProvider.NameIDFormat = &params.NameIDFormat
}
}

if modified {
if err := db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Eager().Update(provider); terr != nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/models/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ type SAMLProvider struct {

AttributeMapping SAMLAttributeMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"`

NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"`

CreatedAt time.Time `db:"created_at" json:"-"`
UpdatedAt time.Time `db:"updated_at" json:"-"`
}
Expand Down
3 changes: 3 additions & 0 deletions migrations/20240314092811_add_saml_name_id_format.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
do $$ begin
alter table {{ index .Options "Namespace" }}.saml_providers add column if not exists name_id_format text null;
end $$

0 comments on commit fc51dff

Please sign in to comment.