diff --git a/etree.go b/etree.go
index c3bb746..07e1a6d 100644
--- a/etree.go
+++ b/etree.go
@@ -50,6 +50,13 @@ type ReadSettings struct {
// preserve them instead of keeping only one. Default: false.
PreserveDuplicateAttrs bool
+ // ValidateInput forces all ReadFrom* methods to validate that the
+ // provided input is composed of well-formed XML before processing it. If
+ // invalid XML is detected, the ReadFrom* methods return an error. Because
+ // this option requires the input to be processed twice, it incurs a
+ // significant performance penalty. Default: false.
+ ValidateInput bool
+
// Entity to be passed to standard xml.Decoder. Default: nil.
Entity map[string]string
@@ -66,9 +73,6 @@ func newReadSettings() ReadSettings {
CharsetReader: func(label string, input io.Reader) (io.Reader, error) {
return input, nil
},
- Permissive: false,
- PreserveCData: false,
- Entity: nil,
}
}
@@ -353,6 +357,23 @@ func (d *Document) SetRoot(e *Element) {
// ReadFrom reads XML from the reader 'r' into this document. The function
// returns the number of bytes read and any error encountered.
+//
+// If ReadSettings.ValidateInput is set, the entire input is first buffered
+// in memory and validated; nothing is parsed into the document when the
+// validation fails, and the returned byte count is then 0.
func (d *Document) ReadFrom(r io.Reader) (n int64, err error) {
+	if d.ReadSettings.ValidateInput {
+		// validateXML drains whatever reader it is given, so buffer the
+		// input and give the validation and parsing passes their own
+		// readers over the same bytes.
+		var buf bytes.Buffer
+		if _, err := buf.ReadFrom(r); err != nil {
+			return 0, err
+		}
+		if err := validateXML(bytes.NewReader(buf.Bytes()), d.ReadSettings); err != nil {
+			return 0, err
+		}
+		r = bytes.NewReader(buf.Bytes())
+	}
return d.Element.readFrom(r, d.ReadSettings)
}
@@ -380,6 +389,54 @@ func (d *Document) ReadFromString(s string) error {
return err
}
+// validateXML determines if the data read from the reader 'r' contains
+// well-formed XML according to the rules set by the go xml package. It
+// returns nil for valid input, the decoder's error if the root element
+// cannot be decoded, or ErrXML if anything other than whitespace, comments
+// or processing instructions follows the root element.
+func validateXML(r io.Reader, settings ReadSettings) error {
+	dec := newDecoder(r, settings)
+	if err := dec.Decode(new(interface{})); err != nil {
+		return err
+	}
+
+	// Decode() stops after the root element. Only whitespace, comments and
+	// processing instructions may legally follow it, so walk the remaining
+	// tokens and reject anything else.
+	for {
+		tok, err := dec.Token()
+		if err == io.EOF {
+			return nil
+		}
+		if err != nil {
+			return ErrXML
+		}
+		switch t := tok.(type) {
+		case xml.Comment, xml.ProcInst:
+			// Permitted after the root element.
+		case xml.CharData:
+			for _, c := range t {
+				if c != ' ' && c != '\t' && c != '\r' && c != '\n' {
+					return ErrXML
+				}
+			}
+		default:
+			return ErrXML
+		}
+	}
+}
+
+// newDecoder returns an xml.Decoder that reads from 'r' and has its
+// options copied from the given read settings.
+func newDecoder(r io.Reader, settings ReadSettings) *xml.Decoder {
+	dec := xml.NewDecoder(r)
+	dec.Strict = !settings.Permissive
+	dec.AutoClose = settings.AutoClose
+	dec.Entity = settings.Entity
+	dec.CharsetReader = settings.CharsetReader
+	return dec
+}
+
// WriteTo serializes the document out to the writer 'w'. The function returns
// the number of bytes written and any error encountered.
func (d *Document) WriteTo(w io.Writer) (n int64, err error) {
@@ -835,10 +873,7 @@ func (e *Element) readFrom(ri io.Reader, settings ReadSettings) (n int64, err er
r = newXmlSimpleReader(ri)
}
- dec := xml.NewDecoder(r)
- dec.CharsetReader = settings.CharsetReader
- dec.Strict = !settings.Permissive
- dec.Entity = settings.Entity
+ dec := newDecoder(r, settings)
var stack stack
stack.push(e)
diff --git a/etree_test.go b/etree_test.go
index 11c0eea..57c991e 100644
--- a/etree_test.go
+++ b/etree_test.go
@@ -1524,3 +1524,35 @@ func TestNotNil(t *testing.T) {
t.Error("got:\n" + got)
}
}
+
+// TestValidateInput checks that ReadFromString, with ValidateInput enabled,
+// accepts well-formed XML and reports the expected error for malformed
+// input. An empty 'err' string in the table means no error is expected.
+func TestValidateInput(t *testing.T) {
+	tests := []struct {
+		s   string
+		err string
+	}{
+		{`<root>x</root>`, ""},
+		{`<root/>`, ""},
+		{`<root>x`, `XML syntax error on line 1: unexpected EOF`},
+		{`</root>`, `XML syntax error on line 1: unexpected end element </root>`},
+		{`<>`, `XML syntax error on line 1: expected element name after <`},
+		{`<root>x</root>trailing`, "etree: invalid XML format"},
+		{`<root>x</root><`, "etree: invalid XML format"},
+		{`<root>x</root1>`, `XML syntax error on line 1: element <root> closed by </root1>`},
+	}
+
+	for i, test := range tests {
+		doc := NewDocument()
+		doc.ReadSettings.ValidateInput = true
+		// Compare error text; "" stands for a nil error.
+		var got string
+		if err := doc.ReadFromString(test.s); err != nil {
+			got = err.Error()
+		}
+		if got != test.err {
+			t.Errorf("etree: test #%d:\nExpected error:\n %q\nReceived error:\n %q", i, test.err, got)
+		}
+	}
+}