Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: corrected inflate and deflate logic for xml, and corresponding tests with some refactoring for SAML sessions #92

Merged
merged 7 commits into from
Dec 11, 2024
8 changes: 6 additions & 2 deletions pkg/provider/attribute_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat)
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.TimeFormat, p.Expiration)
return nil
},
func() {
Expand All @@ -139,7 +139,11 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
// create enveloped signature
checkerInstance.WithLogicStep(
func() error {
return createPostSignature(r.Context(), response, p)
cert, key, err := getResponseCert(r.Context(), p.storage)
if err != nil {
return err
}
return createPostSignature(response, key, cert, p.conf.SignatureAlgorithm)
},
func() {
http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError)
Expand Down
8 changes: 5 additions & 3 deletions pkg/provider/identityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ type IdentityProvider struct {
metadataEndpoint *Endpoint
endpoints *Endpoints

timeFormat string
TimeFormat string
Expiration time.Duration
}

type Endpoints struct {
Expand Down Expand Up @@ -97,7 +98,8 @@ func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storag
postTemplate: postTemplate,
logoutTemplate: logoutTemplate,
endpoints: endpointConfigToEndpoints(conf.Endpoints),
timeFormat: DefaultTimeFormat,
TimeFormat: DefaultTimeFormat,
Expiration: DefaultExpiration,
}

if conf.MetadataIDPConfig == nil {
Expand Down Expand Up @@ -153,7 +155,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto
return nil, nil, err
}

metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat)
metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.TimeFormat)
return metadata, aaMetadata, nil
}

Expand Down
59 changes: 34 additions & 25 deletions pkg/provider/login.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package provider

import (
"context"
"fmt"
"net/http"

"github.com/zitadel/logging"

"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)

func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Request) {
Expand All @@ -16,7 +20,6 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
Issuer: p.GetEntityID(r.Context()),
}

ctx := r.Context()
if err := r.ParseForm(); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
Expand All @@ -34,52 +37,58 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID)
if err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat))
response.sendBackResponse(r, w, p.errorResponse(response, StatusCodeRequestDenied, fmt.Errorf("failed to get request: %w", err).Error()))
return
}
response.RequestID = authRequest.GetAuthRequestID()
response.RelayState = authRequest.GetRelayState()
response.ProtocolBinding = authRequest.GetBindingType()
response.AcsUrl = authRequest.GetAccessConsumerServiceURL()

if !authRequest.Done() {
entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
return
}
response.Audience = entityID

entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
samlResponse, err := p.loginResponse(r.Context(), authRequest, response)
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat))
return
}
response.Audience = entityID

response.sendBackResponse(r, w, samlResponse)
return
}

func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
if !authRequest.Done() {
logging.Error(StatusCodeAuthNFailed)
return nil, fmt.Errorf(StatusCodeAuthNFailed)
}

attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
return
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat)
cert, key, err := getResponseCert(ctx, p.storage)
if err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

switch response.ProtocolBinding {
case PostBinding:
if err := createPostSignature(r.Context(), samlResponse, p); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
case RedirectBinding:
if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration)
if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeResponder)
}
return samlResponse, nil
}

response.sendBackResponse(r, w, samlResponse)
return
func (p *IdentityProvider) errorResponse(response *Response, reason string, description string) *samlp.ResponseType {
return response.makeFailedResponse(reason, description, p.TimeFormat)
}
33 changes: 19 additions & 14 deletions pkg/provider/login_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package provider

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang/mock/gomock"
Expand All @@ -23,9 +23,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
Done bool
}
type res struct {
code int
err bool
state string
code int
err bool
state string
inflate bool
b64 bool
}
type sp struct {
appID string
Expand Down Expand Up @@ -235,7 +237,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
ID: "test",
AuthRequestID: "test",
Binding: RedirectBinding,
AcsURL: "url",
AcsURL: "https://sp.example.com",
RelayState: "relaystate",
UserID: "userid",
Done: false,
Expand All @@ -247,9 +249,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
},
},
res{
code: 500,
state: "",
err: false,
code: 302,
state: StatusCodeAuthNFailed,
err: false,
inflate: true,
b64: true,
}},
}

Expand Down Expand Up @@ -297,14 +301,15 @@ func TestSSO_loginHandleFunc(t *testing.T) {
defer func() {
_ = res.Body.Close()
}()
response, err := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.res.code {
t.Errorf("ssoHandleFunc() code got = %v, want %v", res.StatusCode, tt.res)
return
}

// currently only checked for redirect binding
if tt.res.state != "" {
if err := parseForState(string(response), tt.res.state); err != nil {
responseURL, err := url.Parse(res.Header.Get("Location"))
if err != nil {
t.Errorf("error while parsing url")
}

if err := parseForState(tt.res.inflate, tt.res.b64, responseURL.Query().Get("SAMLResponse"), tt.res.state); err != nil {
t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state)
return
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -69,10 +69,10 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
checkIfRequestTimeIsStillValid(
func() string { return logoutRequest.IssueInstant },
func() string { return logoutRequest.NotOnOrAfter },
p.timeFormat,
p.TimeFormat,
),
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -106,7 +106,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque

response.sendBackLogoutResponse(
w,
response.makeSuccessfulLogoutResponse(p.timeFormat),
response.makeSuccessfulLogoutResponse(p.TimeFormat),
)
logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text))
}
Expand Down
39 changes: 6 additions & 33 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,55 +55,28 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam
}
}

func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeUnsupportedlLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestUnsupported,
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makePartialLogoutResponse(
func (r *LogoutResponse) makeFailedLogoutResponse(
reason string,
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodePartialLogout,
reason,
message,
getIssuer(r.Issuer),
)
}

func (r *LogoutResponse) makeDeniedLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
StatusCodeRequestDenied,
message,
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
)
}
Expand Down
14 changes: 5 additions & 9 deletions pkg/provider/post.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package provider

import (
"context"
"crypto/rsa"
"encoding/base64"
"reflect"

Expand Down Expand Up @@ -63,16 +63,12 @@ func verifyPostSignature(
}

func createPostSignature(
ctx context.Context,
samlResponse *samlp.ResponseType,
idp *IdentityProvider,
key *rsa.PrivateKey,
cert []byte,
signatureAlgorithm string,
) error {
cert, key, err := getResponseCert(ctx, idp.storage)
if err != nil {
return err
}

signer, err := signature.GetSigner(cert, key, idp.conf.SignatureAlgorithm)
signer, err := signature.GetSigner(cert, key, signatureAlgorithm)
if err != nil {
return err
}
Expand Down
Loading