diff --git a/pkg/provider/logout.go b/pkg/provider/logout.go
index a6e59a3..03eaef9 100644
--- a/pkg/provider/logout.go
+++ b/pkg/provider/logout.go
@@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
- http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
+ response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
},
)
@@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
- http.Error(w, fmt.Errorf("failed to decode request: %w", err).Error(), http.StatusInternalServerError)
+ response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
},
)
diff --git a/pkg/provider/logout_response.go b/pkg/provider/logout_response.go
index c80311e..c666bb0 100644
--- a/pkg/provider/logout_response.go
+++ b/pkg/provider/logout_response.go
@@ -1,15 +1,12 @@
package provider
import (
- "bufio"
- "bytes"
"encoding/base64"
- "encoding/xml"
- "fmt"
"html/template"
"net/http"
"time"
+ "github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/saml/pkg/provider/xml/saml"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)
@@ -32,39 +29,24 @@ type LogoutResponseForm struct {
}
func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *samlp.LogoutResponseType) {
- var xmlbuff bytes.Buffer
-
- memWriter := bufio.NewWriter(&xmlbuff)
- _, err := memWriter.Write([]byte(xml.Header))
+ respData, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}
- encoder := xml.NewEncoder(memWriter)
- err = encoder.Encode(resp)
- if err != nil {
- r.ErrorFunc(err)
- return
+ if r.LogoutURL == "" {
+ if err := xml.Write(w, respData); err != nil {
+ r.ErrorFunc(err)
+ return
+ }
}
- err = memWriter.Flush()
- if err != nil {
- r.ErrorFunc(err)
- return
- }
- samlMessageBytes := xmlbuff.Bytes()
-
data := LogoutResponseForm{
RelayState: r.RelayState,
- SAMLResponse: base64.StdEncoding.EncodeToString(samlMessageBytes),
+ SAMLResponse: base64.StdEncoding.EncodeToString(respData),
LogoutURL: r.LogoutURL,
}
- if data.LogoutURL == "" {
- w.Write(samlMessageBytes)
- http.Error(w, fmt.Errorf("failed to find logout url").Error(), http.StatusInternalServerError)
- return
- }
if err := r.LogoutTemplate.Execute(w, data); err != nil {
r.ErrorFunc(err)
diff --git a/pkg/provider/redirect.go b/pkg/provider/redirect.go
index b12ce64..45f76e5 100644
--- a/pkg/provider/redirect.go
+++ b/pkg/provider/redirect.go
@@ -71,12 +71,12 @@ func createRedirectSignature(
idp *IdentityProvider,
response *Response,
) error {
- respStr, err := xml.Marshal(samlResponse)
+ resp, err := xml.Marshal(samlResponse)
if err != nil {
return err
}
- respData, err := xml.DeflateAndBase64([]byte(respStr))
+ respData, err := xml.DeflateAndBase64(resp)
if err != nil {
return err
}
diff --git a/pkg/provider/response.go b/pkg/provider/response.go
index 5452556..1598f71 100644
--- a/pkg/provider/response.go
+++ b/pkg/provider/response.go
@@ -41,8 +41,28 @@ type Response struct {
}
func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) {
+
+}
+
+type AuthResponseForm struct {
+ RelayState string
+ SAMLResponse string
+ AssertionConsumerServiceURL string
+}
+
+func (r *Response) sendBackResponse(
+ req *http.Request,
+ w http.ResponseWriter,
+ resp *samlp.ResponseType,
+) {
+ respData, err := xml.Marshal(resp)
+ if err != nil {
+ r.ErrorFunc(err)
+ return
+ }
+
if r.AcsUrl == "" {
- if err := xml.Write(w, []byte(response)); err != nil {
+ if err := xml.Write(w, respData); err != nil {
r.ErrorFunc(err)
return
}
@@ -50,7 +70,7 @@ func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, resp
switch r.ProtocolBinding {
case PostBinding:
- respData := base64.StdEncoding.EncodeToString([]byte(response))
+ respData := base64.StdEncoding.EncodeToString(respData)
data := AuthResponseForm{
r.RelayState,
@@ -58,49 +78,24 @@ func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, resp
r.AcsUrl,
}
- if data.AssertionConsumerServiceURL == "" {
- w.Write([]byte(response))
- http.Error(w, fmt.Errorf("failed to find AssertionConsumerServiceURL").Error(), http.StatusInternalServerError)
- return
- }
if err := r.PostTemplate.Execute(w, data); err != nil {
r.ErrorFunc(err)
return
}
case RedirectBinding:
- respData, err := xml.DeflateAndBase64([]byte(response))
+ respData, err := xml.DeflateAndBase64(respData)
if err != nil {
r.ErrorFunc(err)
return
}
- http.Redirect(w, request, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
+ http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
return
default:
//TODO: no binding
}
}
-type AuthResponseForm struct {
- RelayState string
- SAMLResponse string
- AssertionConsumerServiceURL string
-}
-
-func (r *Response) sendBackResponse(
- req *http.Request,
- w http.ResponseWriter,
- resp *samlp.ResponseType,
-) {
- respStr, err := xml.Marshal(resp)
- if err != nil {
- r.ErrorFunc(err)
- return
- }
-
- r.doResponse(req, w, respStr)
-}
-
func (r *Response) makeUnsupportedBindingResponse(
message string,
timeFormat string,
diff --git a/pkg/provider/signature/signature_test.go b/pkg/provider/signature/signature_test.go
index 766af15..1dad5e1 100644
--- a/pkg/provider/signature/signature_test.go
+++ b/pkg/provider/signature/signature_test.go
@@ -512,7 +512,7 @@ func TestSignature_CreatePost(t *testing.T) {
}
resp.Signature = sig
- respStr, err := saml_xml.Marshal(resp)
+ respData, err := saml_xml.Marshal(resp)
if err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Create() marshall response for signing")
@@ -521,7 +521,7 @@ func TestSignature_CreatePost(t *testing.T) {
}
doc := etree.NewDocument()
- if err := doc.ReadFromBytes([]byte(respStr)); err != nil {
+ if err := doc.ReadFromBytes(respData); err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Cert() failed to read response")
}
diff --git a/pkg/provider/sso.go b/pkg/provider/sso.go
index f521c7f..dafca50 100644
--- a/pkg/provider/sso.go
+++ b/pkg/provider/sso.go
@@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
- http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
+ response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat))
},
)
diff --git a/pkg/provider/sso_test.go b/pkg/provider/sso_test.go
index 2ede723..4b68d8c 100644
--- a/pkg/provider/sso_test.go
+++ b/pkg/provider/sso_test.go
@@ -559,8 +559,8 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
},
res{
- code: 500,
- state: "",
+ code: 200,
+ state: StatusCodeRequestDenied,
err: false,
}},
{
diff --git a/pkg/provider/xml/xml.go b/pkg/provider/xml/xml.go
index bcab08e..ad4e976 100644
--- a/pkg/provider/xml/xml.go
+++ b/pkg/provider/xml/xml.go
@@ -19,27 +19,27 @@ const (
EncodingDeflate = "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE"
)
-func Marshal(data interface{}) (string, error) {
+func Marshal(data interface{}) ([]byte, error) {
var xmlbuff bytes.Buffer
memWriter := bufio.NewWriter(&xmlbuff)
_, err := memWriter.Write([]byte(xml.Header))
if err != nil {
- return "", err
+ return nil, err
}
encoder := xml.NewEncoder(memWriter)
err = encoder.Encode(data)
if err != nil {
- return "", err
+ return nil, err
}
err = memWriter.Flush()
if err != nil {
- return "", err
+ return nil, err
}
- return xmlbuff.String(), nil
+ return xmlbuff.Bytes(), nil
}
func DeflateAndBase64(data []byte) ([]byte, error) {
diff --git a/pkg/provider/xml/xml_test.go b/pkg/provider/xml/xml_test.go
index dfacf51..86cc341 100644
--- a/pkg/provider/xml/xml_test.go
+++ b/pkg/provider/xml/xml_test.go
@@ -1,6 +1,7 @@
package xml_test
import (
+ "slices"
"testing"
"github.com/zitadel/saml/pkg/provider/xml"
@@ -12,7 +13,7 @@ type XML struct {
func Test_XmlMarshal(t *testing.T) {
type res struct {
- metadata string
+ metadata []byte
err bool
}
@@ -25,7 +26,7 @@ func Test_XmlMarshal(t *testing.T) {
name: "xml struct",
arg: "",
res: res{
- metadata: "\n",
+ metadata: []byte("\n"),
err: false,
},
},
@@ -35,13 +36,14 @@ func Test_XmlMarshal(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
xmlStruct := XML{InnerXml: tt.arg}
- xmlStr, err := xml.Marshal(xmlStruct)
+ xmlData, err := xml.Marshal(xmlStruct)
if (err != nil) != tt.res.err {
t.Errorf("Marshal() error: %v", err)
return
}
- if xmlStr != tt.res.metadata {
- t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlStr)
+
+ if !slices.Equal(xmlData, tt.res.metadata) {
+ t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlData)
return
}
})