Skip to content

Commit

Permalink
Configurable deflate-bomb protection
Browse files Browse the repository at this point in the history
  • Loading branch information
russellhaering committed Mar 1, 2023
1 parent 156f1b9 commit 56f4c23
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 24 deletions.
2 changes: 1 addition & 1 deletion decode_logout_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutRequestPOST(encodedRequest s
}

// Parse the raw request - parseResponse is generic
doc, el, err := parseResponse(raw)
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
Expand Down
43 changes: 29 additions & 14 deletions decode_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ import (
"crypto/x509"
"encoding/base64"
"fmt"
"io/ioutil"
"io"

"encoding/xml"

"github.com/beevik/etree"
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
"github.com/russellhaering/gosaml2/types"
dsig "github.com/russellhaering/goxmldsig"
"github.com/russellhaering/goxmldsig/etreeutils"
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
)

const (
defaultMaxDecompressedResponseSize = 5 * 1024 * 1024
)

func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext {
Expand Down Expand Up @@ -174,7 +178,7 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
return fmt.Errorf("unable to decrypt encrypted assertion: %v", derr)
}

doc, _, err := parseResponse(raw)
doc, _, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return fmt.Errorf("unable to create element from decrypted assertion bytes: %v", err)
}
Expand Down Expand Up @@ -250,17 +254,17 @@ func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) er
}
}

//ValidateEncodedResponse both decodes and validates, based on SP
//configuration, an encoded, signed response. It will also appropriately
//decrypt a response if the assertion was encrypted
// ValidateEncodedResponse both decodes and validates, based on SP
// configuration, an encoded, signed response. It will also appropriately
// decrypt a response if the assertion was encrypted
func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (*types.Response, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
return nil, err
}

// Parse the raw response
doc, el, err := parseResponse(raw)
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -330,7 +334,7 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase

var response *types.UnverifiedBaseResponse

err = maybeDeflate(raw, func(maybeXML []byte) error {
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.UnverifiedBaseResponse{}
return xml.Unmarshal(maybeXML, response)
})
Expand All @@ -344,26 +348,37 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase
// maybeDeflate invokes the passed decoder over the passed data. If an error is
// returned, it then attempts to deflate the passed data before re-invoking
// the decoder over the deflated data.
func maybeDeflate(data []byte, decoder func([]byte) error) error {
func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error {
err := decoder(data)
if err == nil {
return nil
}

deflated, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data)))
// Default to 5MB max size
if maxSize == 0 {
maxSize = defaultMaxDecompressedResponseSize
}

lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1)

deflated, err := io.ReadAll(lr)
if err != nil {
return err
}

if int64(len(deflated)) > maxSize {
return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize)
}

return decoder(deflated)
}

// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested
func parseResponse(xml []byte) (*etree.Document, *etree.Element, error) {
func parseResponse(xml []byte, maxSize int64) (*etree.Document, *etree.Element, error) {
var doc *etree.Document
var rawXML []byte

err := maybeDeflate(xml, func(xml []byte) error {
err := maybeDeflate(xml, maxSize, func(xml []byte) error {
doc = etree.NewDocument()
rawXML = xml
return doc.ReadFromBytes(xml)
Expand Down Expand Up @@ -395,7 +410,7 @@ func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutRespon

var response *types.LogoutResponse

err = maybeDeflate(raw, func(maybeXML []byte) error {
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.LogoutResponse{}
return xml.Unmarshal(maybeXML, response)
})
Expand All @@ -413,7 +428,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse
}

// Parse the raw response
doc, el, err := parseResponse(raw)
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
Expand Down
35 changes: 29 additions & 6 deletions decode_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import (
"time"

"github.com/jonboulle/clockwork"
"github.com/russellhaering/goxmldsig"
"github.com/stretchr/testify/require"
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
dsig "github.com/russellhaering/goxmldsig"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestDecodeColonsInLocalNames(t *testing.T) {
t.Skip()
}

_, _, err := parseResponse([]byte(`<x::Root/>`))
_, _, err := parseResponse([]byte(`<x::Root/>`), 0)
require.Error(t, err)
}

Expand All @@ -180,7 +180,7 @@ func TestDecodeDoubleColonInjectionAttackResponse(t *testing.T) {
t.Skip()
}

_, _, err := parseResponse([]byte(doubleColonAssertionInjectionAttackResponse))
_, _, err := parseResponse([]byte(doubleColonAssertionInjectionAttackResponse), 0)
require.Error(t, err)
}

Expand All @@ -194,7 +194,7 @@ func TestMalFormedInput(t *testing.T) {
}

sp := &SAMLServiceProvider{
Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2019, 8, 12, 12, 00, 52, 718, time.UTC))),
Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2019, 8, 12, 12, 00, 52, 718, time.UTC))),
AssertionConsumerServiceURL: "https://saml2.test.astuart.co/sso/saml2",
SignAuthnRequests: true,
IDPCertificateStore: &certStore,
Expand All @@ -203,4 +203,27 @@ func TestMalFormedInput(t *testing.T) {
base64Input := base64.StdEncoding.EncodeToString([]byte(badInput))
_, err = sp.RetrieveAssertionInfo(base64Input)
require.Errorf(t, err, "parent is nil")
}
}

func TestCompressionBombInput(t *testing.T) {
bs, err := ioutil.ReadFile("./testdata/saml_compressed.post")
require.NoError(t, err, "couldn't read compressed post")

block, _ := pem.Decode([]byte(oktaCert))

idpCert, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err, "couldn't parse okta cert pem block")

sp := SAMLServiceProvider{
AssertionConsumerServiceURL: "https://f1f51ddc.ngrok.io/api/sso/saml2/acs/58cafd0573d4f375b8e70e8e",
SPKeyStore: dsig.TLSCertKeyStore(cert),
IDPCertificateStore: &dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{idpCert},
},
Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2017, 3, 17, 20, 00, 0, 0, time.UTC))),
MaximumDecompressedBodySize: 2048,
}

_, err = sp.RetrieveAssertionInfo(string(bs))
require.NoError(t, err, "Assertion info should be retrieved with no error")
}
10 changes: 8 additions & 2 deletions saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,14 @@ type SAMLServiceProvider struct {
SkipSignatureValidation bool
AllowMissingAttributes bool
Clock *dsig.Clock
signingContextMu sync.RWMutex
signingContext *dsig.SigningContext

// MaximumDecompressedBodySize is the maximum size to which a compressed
// SAML document will be decompressed. If a compresed document is exceeds
// this size during decompression an error will be returned.
MaximumDecompressedBodySize int64

signingContextMu sync.RWMutex
signingContext *dsig.SigningContext
}

// RequestedAuthnContext controls which authentication mechanisms are requested of
Expand Down
2 changes: 1 addition & 1 deletion saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func TestSAMLCommentInjection(t *testing.T) {
*/

// To show that we are not vulnerable, we want to prove that we get the canonicalized value using our parser
_, el, err := parseResponse([]byte(commentInjectionAttackResponse))
_, el, err := parseResponse([]byte(commentInjectionAttackResponse), 0)
require.NoError(t, err)
decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
Expand Down

0 comments on commit 56f4c23

Please sign in to comment.