diff --git a/decode_logout_request.go b/decode_logout_request.go
index b1eacc7..57c7eb8 100644
--- a/decode_logout_request.go
+++ b/decode_logout_request.go
@@ -49,7 +49,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutRequestPOST(encodedRequest s
}
// Parse the raw request - parseResponse is generic
- doc, el, err := parseResponse(raw)
+ doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
diff --git a/decode_response.go b/decode_response.go
index 025d1e3..b3258a0 100644
--- a/decode_response.go
+++ b/decode_response.go
@@ -21,15 +21,19 @@ import (
"crypto/x509"
"encoding/base64"
"fmt"
- "io/ioutil"
+ "io"
"encoding/xml"
"github.com/beevik/etree"
+ rtvalidator "github.com/mattermost/xml-roundtrip-validator"
"github.com/russellhaering/gosaml2/types"
dsig "github.com/russellhaering/goxmldsig"
"github.com/russellhaering/goxmldsig/etreeutils"
- rtvalidator "github.com/mattermost/xml-roundtrip-validator"
+)
+
+const (
+ defaultMaxDecompressedResponseSize = 5 * 1024 * 1024
)
func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext {
@@ -174,7 +178,7 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
return fmt.Errorf("unable to decrypt encrypted assertion: %v", derr)
}
- doc, _, err := parseResponse(raw)
+ doc, _, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return fmt.Errorf("unable to create element from decrypted assertion bytes: %v", err)
}
@@ -250,9 +254,9 @@ func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) er
}
}
-//ValidateEncodedResponse both decodes and validates, based on SP
-//configuration, an encoded, signed response. It will also appropriately
-//decrypt a response if the assertion was encrypted
+// ValidateEncodedResponse both decodes and validates, based on SP
+// configuration, an encoded, signed response. It will also appropriately
+// decrypt a response if the assertion was encrypted
func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (*types.Response, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
@@ -260,7 +264,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (
}
// Parse the raw response
- doc, el, err := parseResponse(raw)
+ doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
@@ -330,7 +334,7 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase
var response *types.UnverifiedBaseResponse
- err = maybeDeflate(raw, func(maybeXML []byte) error {
+ err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.UnverifiedBaseResponse{}
return xml.Unmarshal(maybeXML, response)
})
@@ -344,26 +348,37 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase
// maybeDeflate invokes the passed decoder over the passed data. If an error is
// returned, it then attempts to deflate the passed data before re-invoking
// the decoder over the deflated data.
-func maybeDeflate(data []byte, decoder func([]byte) error) error {
+func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error {
err := decoder(data)
if err == nil {
return nil
}
- deflated, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data)))
+ // Default to 5MB max size
+ if maxSize == 0 {
+ maxSize = defaultMaxDecompressedResponseSize
+ }
+
+ lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1)
+
+ deflated, err := io.ReadAll(lr)
if err != nil {
return err
}
+ if int64(len(deflated)) > maxSize {
+ return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize)
+ }
+
return decoder(deflated)
}
// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested
-func parseResponse(xml []byte) (*etree.Document, *etree.Element, error) {
+func parseResponse(xml []byte, maxSize int64) (*etree.Document, *etree.Element, error) {
var doc *etree.Document
var rawXML []byte
- err := maybeDeflate(xml, func(xml []byte) error {
+ err := maybeDeflate(xml, maxSize, func(xml []byte) error {
doc = etree.NewDocument()
rawXML = xml
return doc.ReadFromBytes(xml)
@@ -395,7 +410,7 @@ func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutRespon
var response *types.LogoutResponse
- err = maybeDeflate(raw, func(maybeXML []byte) error {
+ err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.LogoutResponse{}
return xml.Unmarshal(maybeXML, response)
})
@@ -413,7 +428,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse
}
// Parse the raw response
- doc, el, err := parseResponse(raw)
+ doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
diff --git a/decode_response_test.go b/decode_response_test.go
index 4302976..613f9a6 100644
--- a/decode_response_test.go
+++ b/decode_response_test.go
@@ -25,9 +25,9 @@ import (
"time"
"github.com/jonboulle/clockwork"
- "github.com/russellhaering/goxmldsig"
- "github.com/stretchr/testify/require"
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
+ dsig "github.com/russellhaering/goxmldsig"
+ "github.com/stretchr/testify/require"
)
const (
@@ -169,7 +169,7 @@ func TestDecodeColonsInLocalNames(t *testing.T) {
t.Skip()
}
- _, _, err := parseResponse([]byte(``))
+ _, _, err := parseResponse([]byte(``), 0)
require.Error(t, err)
}
@@ -180,7 +180,7 @@ func TestDecodeDoubleColonInjectionAttackResponse(t *testing.T) {
t.Skip()
}
- _, _, err := parseResponse([]byte(doubleColonAssertionInjectionAttackResponse))
+ _, _, err := parseResponse([]byte(doubleColonAssertionInjectionAttackResponse), 0)
require.Error(t, err)
}
@@ -194,7 +194,7 @@ func TestMalFormedInput(t *testing.T) {
}
sp := &SAMLServiceProvider{
- Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2019, 8, 12, 12, 00, 52, 718, time.UTC))),
+ Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2019, 8, 12, 12, 00, 52, 718, time.UTC))),
AssertionConsumerServiceURL: "https://saml2.test.astuart.co/sso/saml2",
SignAuthnRequests: true,
IDPCertificateStore: &certStore,
@@ -203,4 +203,27 @@ func TestMalFormedInput(t *testing.T) {
base64Input := base64.StdEncoding.EncodeToString([]byte(badInput))
_, err = sp.RetrieveAssertionInfo(base64Input)
require.Errorf(t, err, "parent is nil")
-}
\ No newline at end of file
+}
+
+func TestCompressionBombInput(t *testing.T) {
+ bs, err := ioutil.ReadFile("./testdata/saml_compressed.post")
+ require.NoError(t, err, "couldn't read compressed post")
+
+ block, _ := pem.Decode([]byte(oktaCert))
+
+ idpCert, err := x509.ParseCertificate(block.Bytes)
+ require.NoError(t, err, "couldn't parse okta cert pem block")
+
+ sp := SAMLServiceProvider{
+ AssertionConsumerServiceURL: "https://f1f51ddc.ngrok.io/api/sso/saml2/acs/58cafd0573d4f375b8e70e8e",
+ SPKeyStore: dsig.TLSCertKeyStore(cert),
+ IDPCertificateStore: &dsig.MemoryX509CertificateStore{
+ Roots: []*x509.Certificate{idpCert},
+ },
+ Clock: dsig.NewFakeClock(clockwork.NewFakeClockAt(time.Date(2017, 3, 17, 20, 00, 0, 0, time.UTC))),
+ MaximumDecompressedBodySize: 2048,
+ }
+
+ _, err = sp.RetrieveAssertionInfo(string(bs))
+ require.NoError(t, err, "Assertion info should be retrieved with no error")
+}
diff --git a/saml.go b/saml.go
index 8ae3a3a..adc1489 100644
--- a/saml.go
+++ b/saml.go
@@ -74,8 +74,14 @@ type SAMLServiceProvider struct {
SkipSignatureValidation bool
AllowMissingAttributes bool
Clock *dsig.Clock
- signingContextMu sync.RWMutex
- signingContext *dsig.SigningContext
+
+ // MaximumDecompressedBodySize is the maximum size to which a compressed
+ // SAML document will be decompressed. If a compresed document is exceeds
+ // this size during decompression an error will be returned.
+ MaximumDecompressedBodySize int64
+
+ signingContextMu sync.RWMutex
+ signingContext *dsig.SigningContext
}
// RequestedAuthnContext controls which authentication mechanisms are requested of
diff --git a/saml_test.go b/saml_test.go
index 1cfb359..ac84fe6 100644
--- a/saml_test.go
+++ b/saml_test.go
@@ -353,7 +353,7 @@ func TestSAMLCommentInjection(t *testing.T) {
*/
// To show that we are not vulnerable, we want to prove that we get the canonicalized value using our parser
- _, el, err := parseResponse([]byte(commentInjectionAttackResponse))
+ _, el, err := parseResponse([]byte(commentInjectionAttackResponse), 0)
require.NoError(t, err)
decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)