diff --git a/pkg/provider/logout.go b/pkg/provider/logout.go index a6e59a3..03eaef9 100644 --- a/pkg/provider/logout.go +++ b/pkg/provider/logout.go @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque return nil }, func() { - http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError) + response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat)) }, ) @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque return nil }, func() { - http.Error(w, fmt.Errorf("failed to decode request: %w", err).Error(), http.StatusInternalServerError) + response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat)) }, ) diff --git a/pkg/provider/logout_response.go b/pkg/provider/logout_response.go index c80311e..c666bb0 100644 --- a/pkg/provider/logout_response.go +++ b/pkg/provider/logout_response.go @@ -1,15 +1,12 @@ package provider import ( - "bufio" - "bytes" "encoding/base64" - "encoding/xml" - "fmt" "html/template" "net/http" "time" + "github.com/zitadel/saml/pkg/provider/xml" "github.com/zitadel/saml/pkg/provider/xml/saml" "github.com/zitadel/saml/pkg/provider/xml/samlp" ) @@ -32,39 +29,24 @@ type LogoutResponseForm struct { } func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *samlp.LogoutResponseType) { - var xmlbuff bytes.Buffer - - memWriter := bufio.NewWriter(&xmlbuff) - _, err := memWriter.Write([]byte(xml.Header)) + respData, err := xml.Marshal(resp) if err != nil { r.ErrorFunc(err) return } - encoder := xml.NewEncoder(memWriter) - err = encoder.Encode(resp) - if err != nil { - r.ErrorFunc(err) - return + if r.LogoutURL == "" { + if err := xml.Write(w, respData); err != nil { + r.ErrorFunc(err) + return + } } - err = memWriter.Flush() - if err != nil { - r.ErrorFunc(err) - return - } - samlMessageBytes := xmlbuff.Bytes() - data := LogoutResponseForm{ RelayState: r.RelayState, - SAMLResponse: base64.StdEncoding.EncodeToString(samlMessageBytes), + SAMLResponse: base64.StdEncoding.EncodeToString(respData), LogoutURL: r.LogoutURL, } - if data.LogoutURL == "" { - w.Write(samlMessageBytes) - http.Error(w, fmt.Errorf("failed to find logout url").Error(), http.StatusInternalServerError) - return - } if err := r.LogoutTemplate.Execute(w, data); err != nil { r.ErrorFunc(err) diff --git a/pkg/provider/redirect.go b/pkg/provider/redirect.go index b12ce64..45f76e5 100644 --- a/pkg/provider/redirect.go +++ b/pkg/provider/redirect.go @@ -71,12 +71,12 @@ func createRedirectSignature( idp *IdentityProvider, response *Response, ) error { - respStr, err := xml.Marshal(samlResponse) + resp, err := xml.Marshal(samlResponse) if err != nil { return err } - respData, err := xml.DeflateAndBase64([]byte(respStr)) + respData, err := xml.DeflateAndBase64(resp) if err != nil { return err } diff --git a/pkg/provider/response.go b/pkg/provider/response.go index 5452556..1598f71 100644 --- a/pkg/provider/response.go +++ b/pkg/provider/response.go @@ -41,8 +41,28 @@ type Response struct { } func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) { + +} + +type AuthResponseForm struct { + RelayState string + SAMLResponse string + AssertionConsumerServiceURL string +} + +func (r *Response) sendBackResponse( + req *http.Request, + w http.ResponseWriter, + resp *samlp.ResponseType, +) { + respData, err := xml.Marshal(resp) + if err != nil { + r.ErrorFunc(err) + return + } + if r.AcsUrl == "" { - if err := xml.Write(w, []byte(response)); err != nil { + if err := xml.Write(w, respData); err != nil { r.ErrorFunc(err) return } @@ -50,7 +70,7 @@ func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, resp switch r.ProtocolBinding { case PostBinding: - respData := base64.StdEncoding.EncodeToString([]byte(response)) + respData := base64.StdEncoding.EncodeToString(respData) data := AuthResponseForm{ r.RelayState, @@ -58,49 +78,24 @@ func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, resp r.AcsUrl, } - if data.AssertionConsumerServiceURL == "" { - w.Write([]byte(response)) - http.Error(w, fmt.Errorf("failed to find AssertionConsumerServiceURL").Error(), http.StatusInternalServerError) - return - } if err := r.PostTemplate.Execute(w, data); err != nil { r.ErrorFunc(err) return } case RedirectBinding: - respData, err := xml.DeflateAndBase64([]byte(response)) + respData, err := xml.DeflateAndBase64(respData) if err != nil { r.ErrorFunc(err) return } - http.Redirect(w, request, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound) + http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound) return default: //TODO: no binding } } -type AuthResponseForm struct { - RelayState string - SAMLResponse string - AssertionConsumerServiceURL string -} - -func (r *Response) sendBackResponse( - req *http.Request, - w http.ResponseWriter, - resp *samlp.ResponseType, -) { - respStr, err := xml.Marshal(resp) - if err != nil { - r.ErrorFunc(err) - return - } - - r.doResponse(req, w, respStr) -} - func (r *Response) makeUnsupportedBindingResponse( message string, timeFormat string, diff --git a/pkg/provider/signature/signature_test.go b/pkg/provider/signature/signature_test.go index 766af15..1dad5e1 100644 --- a/pkg/provider/signature/signature_test.go +++ b/pkg/provider/signature/signature_test.go @@ -512,7 +512,7 @@ func TestSignature_CreatePost(t *testing.T) { } resp.Signature = sig - respStr, err := saml_xml.Marshal(resp) + respData, err := saml_xml.Marshal(resp) if err != nil { if (err != nil) != tt.res.err { t.Errorf("Create() marshall response for signing") @@ -521,7 +521,7 @@ func TestSignature_CreatePost(t *testing.T) { } doc := etree.NewDocument() - if err := doc.ReadFromBytes([]byte(respStr)); err != nil { + if err := doc.ReadFromBytes(respData); err != nil { if (err != nil) != tt.res.err { t.Errorf("Cert() failed to read response") } diff --git a/pkg/provider/sso.go b/pkg/provider/sso.go index f521c7f..dafca50 100644 --- a/pkg/provider/sso.go +++ b/pkg/provider/sso.go @@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) return nil }, func() { - http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError) + response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat)) }, ) diff --git a/pkg/provider/sso_test.go b/pkg/provider/sso_test.go index 2ede723..4b68d8c 100644 --- a/pkg/provider/sso_test.go +++ b/pkg/provider/sso_test.go @@ -559,8 +559,8 @@ func TestSSO_ssoHandleFunc(t *testing.T) { }, }, res{ - code: 500, - state: "", + code: 200, + state: StatusCodeRequestDenied, err: false, }}, { diff --git a/pkg/provider/xml/xml.go b/pkg/provider/xml/xml.go index bcab08e..ad4e976 100644 --- a/pkg/provider/xml/xml.go +++ b/pkg/provider/xml/xml.go @@ -19,27 +19,27 @@ const ( EncodingDeflate = "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE" ) -func Marshal(data interface{}) (string, error) { +func Marshal(data interface{}) ([]byte, error) { var xmlbuff bytes.Buffer memWriter := bufio.NewWriter(&xmlbuff) _, err := memWriter.Write([]byte(xml.Header)) if err != nil { - return "", err + return nil, err } encoder := xml.NewEncoder(memWriter) err = encoder.Encode(data) if err != nil { - return "", err + return nil, err } err = memWriter.Flush() if err != nil { - return "", err + return nil, err } - return xmlbuff.String(), nil + return xmlbuff.Bytes(), nil } func DeflateAndBase64(data []byte) ([]byte, error) { diff --git a/pkg/provider/xml/xml_test.go b/pkg/provider/xml/xml_test.go index dfacf51..86cc341 100644 --- a/pkg/provider/xml/xml_test.go +++ b/pkg/provider/xml/xml_test.go @@ -1,6 +1,7 @@ package xml_test import ( + "slices" "testing" "github.com/zitadel/saml/pkg/provider/xml" @@ -12,7 +13,7 @@ type XML struct { func Test_XmlMarshal(t *testing.T) { type res struct { - metadata string + metadata []byte err bool } @@ -25,7 +26,7 @@ func Test_XmlMarshal(t *testing.T) { name: "xml struct", arg: "", res: res{ - metadata: "\n", + metadata: []byte("\n"), err: false, }, }, @@ -35,13 +36,14 @@ func Test_XmlMarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { xmlStruct := XML{InnerXml: tt.arg} - xmlStr, err := xml.Marshal(xmlStruct) + xmlData, err := xml.Marshal(xmlStruct) if (err != nil) != tt.res.err { t.Errorf("Marshal() error: %v", err) return } - if xmlStr != tt.res.metadata { - t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlStr) + + if !slices.Equal(xmlData, tt.res.metadata) { + t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlData) return } })