Skip to content

Commit

Permalink
fix: change handling of responses without url to respond to
Browse files Browse the repository at this point in the history
  • Loading branch information
stebenz committed Nov 16, 2023
1 parent 7ae5510 commit 689f3f2
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 74 deletions.
4 changes: 2 additions & 2 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
http.Error(w, fmt.Errorf("failed to decode request: %w", err).Error(), http.StatusInternalServerError)
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
},
)

Expand Down
34 changes: 8 additions & 26 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package provider

import (
"bufio"
"bytes"
"encoding/base64"
"encoding/xml"
"fmt"
"html/template"
"net/http"
"time"

"github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/saml/pkg/provider/xml/saml"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)
Expand All @@ -32,39 +29,24 @@ type LogoutResponseForm struct {
}

func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *samlp.LogoutResponseType) {
var xmlbuff bytes.Buffer

memWriter := bufio.NewWriter(&xmlbuff)
_, err := memWriter.Write([]byte(xml.Header))
respData, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

encoder := xml.NewEncoder(memWriter)
err = encoder.Encode(resp)
if err != nil {
r.ErrorFunc(err)
return
if r.LogoutURL == "" {
if err := xml.Write(w, respData); err != nil {
r.ErrorFunc(err)
return
}
}

err = memWriter.Flush()
if err != nil {
r.ErrorFunc(err)
return
}
samlMessageBytes := xmlbuff.Bytes()

data := LogoutResponseForm{
RelayState: r.RelayState,
SAMLResponse: base64.StdEncoding.EncodeToString(samlMessageBytes),
SAMLResponse: base64.StdEncoding.EncodeToString(respData),
LogoutURL: r.LogoutURL,
}
if data.LogoutURL == "" {
w.Write(samlMessageBytes)
http.Error(w, fmt.Errorf("failed to find logout url").Error(), http.StatusInternalServerError)
return
}

if err := r.LogoutTemplate.Execute(w, data); err != nil {
r.ErrorFunc(err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ func createRedirectSignature(
idp *IdentityProvider,
response *Response,
) error {
respStr, err := xml.Marshal(samlResponse)
resp, err := xml.Marshal(samlResponse)
if err != nil {
return err
}

respData, err := xml.DeflateAndBase64([]byte(respStr))
respData, err := xml.DeflateAndBase64(resp)
if err != nil {
return err
}
Expand Down
53 changes: 24 additions & 29 deletions pkg/provider/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,66 +41,61 @@ type Response struct {
}

func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) {

}

type AuthResponseForm struct {
RelayState string
SAMLResponse string
AssertionConsumerServiceURL string
}

func (r *Response) sendBackResponse(
req *http.Request,
w http.ResponseWriter,
resp *samlp.ResponseType,
) {
respData, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

if r.AcsUrl == "" {
if err := xml.Write(w, []byte(response)); err != nil {
if err := xml.Write(w, respData); err != nil {
r.ErrorFunc(err)
return
}
}

switch r.ProtocolBinding {
case PostBinding:
respData := base64.StdEncoding.EncodeToString([]byte(response))
respData := base64.StdEncoding.EncodeToString(respData)

data := AuthResponseForm{
r.RelayState,
respData,
r.AcsUrl,
}

if data.AssertionConsumerServiceURL == "" {
w.Write([]byte(response))
http.Error(w, fmt.Errorf("failed to find AssertionConsumerServiceURL").Error(), http.StatusInternalServerError)
return
}
if err := r.PostTemplate.Execute(w, data); err != nil {
r.ErrorFunc(err)
return
}
case RedirectBinding:
respData, err := xml.DeflateAndBase64([]byte(response))
respData, err := xml.DeflateAndBase64(respData)
if err != nil {
r.ErrorFunc(err)
return
}

http.Redirect(w, request, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
return
default:
//TODO: no binding
}
}

type AuthResponseForm struct {
RelayState string
SAMLResponse string
AssertionConsumerServiceURL string
}

func (r *Response) sendBackResponse(
req *http.Request,
w http.ResponseWriter,
resp *samlp.ResponseType,
) {
respStr, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

r.doResponse(req, w, respStr)
}

func (r *Response) makeUnsupportedBindingResponse(
message string,
timeFormat string,
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/signature/signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ func TestSignature_CreatePost(t *testing.T) {
}
resp.Signature = sig

respStr, err := saml_xml.Marshal(resp)
respData, err := saml_xml.Marshal(resp)
if err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Create() marshall response for signing")
Expand All @@ -521,7 +521,7 @@ func TestSignature_CreatePost(t *testing.T) {
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes([]byte(respStr)); err != nil {
if err := doc.ReadFromBytes(respData); err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Cert() failed to read response")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat))
},
)

Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,8 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
},
res{
code: 500,
state: "",
code: 200,
state: StatusCodeRequestDenied,
err: false,
}},
{
Expand Down
10 changes: 5 additions & 5 deletions pkg/provider/xml/xml.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ const (
EncodingDeflate = "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE"
)

func Marshal(data interface{}) (string, error) {
func Marshal(data interface{}) ([]byte, error) {
var xmlbuff bytes.Buffer

memWriter := bufio.NewWriter(&xmlbuff)
_, err := memWriter.Write([]byte(xml.Header))
if err != nil {
return "", err
return nil, err
}

encoder := xml.NewEncoder(memWriter)
err = encoder.Encode(data)
if err != nil {
return "", err
return nil, err
}

err = memWriter.Flush()
if err != nil {
return "", err
return nil, err
}

return xmlbuff.String(), nil
return xmlbuff.Bytes(), nil
}

func DeflateAndBase64(data []byte) ([]byte, error) {
Expand Down
12 changes: 7 additions & 5 deletions pkg/provider/xml/xml_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package xml_test

import (
"slices"
"testing"

"github.com/zitadel/saml/pkg/provider/xml"
Expand All @@ -12,7 +13,7 @@ type XML struct {

func Test_XmlMarshal(t *testing.T) {
type res struct {
metadata string
metadata []byte
err bool
}

Expand All @@ -25,7 +26,7 @@ func Test_XmlMarshal(t *testing.T) {
name: "xml struct",
arg: "<test></test>",
res: res{
metadata: "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<XML><test></test></XML>",
metadata: []byte("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<XML><test></test></XML>"),
err: false,
},
},
Expand All @@ -35,13 +36,14 @@ func Test_XmlMarshal(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
xmlStruct := XML{InnerXml: tt.arg}

xmlStr, err := xml.Marshal(xmlStruct)
xmlData, err := xml.Marshal(xmlStruct)
if (err != nil) != tt.res.err {
t.Errorf("Marshal() error: %v", err)
return
}
if xmlStr != tt.res.metadata {
t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlStr)

if !slices.Equal(xmlData, tt.res.metadata) {
t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlData)
return
}
})
Expand Down

0 comments on commit 689f3f2

Please sign in to comment.