Skip to content

Commit 98aeb29

Browse files
authored
Merge pull request #636 from dubinsky/xmlloader-with-xmlreader
2 parents 3f39bdd + 1cc9ea8 commit 98aeb29

File tree

5 files changed

+84
-32
lines changed

5 files changed

+84
-32
lines changed

jvm/src/test/scala/scala/xml/XMLTest.scala

+28
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,34 @@ class XMLTestJVM {
657657
def namespaceAware2: Unit =
658658
roundtrip(namespaceAware = true, """<book xmlns="http://docbook.org/ns/docbook" xmlns:xi="http://www.w3.org/2001/XInclude"><svg xmlns:svg="http://www.w3.org/2000/svg"/></book>""")
659659

660+
/** Checks that a loader obtained from `XML.withXMLReader` parses through the supplied XMLFilter. */
@UnitTest
def useXMLReaderWithXMLFilter(): Unit = {
  val parent: org.xml.sax.XMLReader =
    javax.xml.parsers.SAXParserFactory.newInstance.newSAXParser.getXMLReader
  // The filter rewrites every 'a' in character data to 'b' before forwarding it,
  // so the loaded text differs from the input only if the filter was really used.
  // Element names do not go through characters(), hence the <a> tag stays as is.
  val filter: org.xml.sax.XMLFilter = new org.xml.sax.helpers.XMLFilterImpl(parent) {
    override def characters(ch: Array[Char], start: Int, length: Int): Unit = {
      for (i <- start until start + length) if (ch(i) == 'a') ch(i) = 'b'
      super.characters(ch, start, length)
    }
  }
  // Expected value goes first so that a failure message labels the two values correctly.
  assertEquals("<a>cbffeebbby</a>", XML.withXMLReader(filter).loadString("<a>caffeeaaay</a>").toString)
}
671+
672+
/**
 * Checks that loading does not replace an ErrorHandler that was already configured
 * on the reader: parsing malformed input must be reported to the handler installed here.
 */
@UnitTest
def checkThatErrorHandlerIsNotOverwritten(): Unit = {
  var gotAnError: Boolean = false
  // NOTE(review): this relies on XML.reader handing back the reader that XML.loadString
  // uses, as the original test did - confirm if the reader lifecycle changes.
  val reader: org.xml.sax.XMLReader = XML.reader
  val previousHandler: org.xml.sax.ErrorHandler = reader.getErrorHandler
  reader.setErrorHandler(new org.xml.sax.ErrorHandler {
    override def warning(e: SAXParseException): Unit = gotAnError = true
    override def error(e: SAXParseException): Unit = gotAnError = true
    override def fatalError(e: SAXParseException): Unit = gotAnError = true
  })
  try {
    // "<a>" is not well-formed (the element is never closed).
    XML.loadString("<a>")
  } catch {
    // An exception from the parser is tolerated; only the handler callback is under test.
    case _: org.xml.sax.SAXParseException =>
  } finally {
    // Put back whatever handler was configured before, so this test's handler
    // does not stay installed on the reader after the test finishes.
    reader.setErrorHandler(previousHandler)
  }
  assertTrue(gotAnError)
}
687+
660688
@UnitTest
661689
def nodeSeqNs: Unit = {
662690
val x = {

shared/src/main/scala/scala/xml/XML.scala

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ package scala
1414
package xml
1515

1616
import factory.XMLLoader
17-
import java.io.{ File, FileDescriptor, FileInputStream, FileOutputStream }
18-
import java.io.{ InputStream, Reader, StringReader }
17+
import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream}
18+
import java.io.{InputStream, Reader, StringReader}
1919
import java.nio.channels.Channels
2020
import scala.util.control.Exception.ultimately
2121

@@ -72,6 +72,10 @@ object XML extends XMLLoader[Elem] {
7272
def withSAXParser(p: SAXParser): XMLLoader[Elem] =
7373
new XMLLoader[Elem] { override val parser: SAXParser = p }
7474

75+
/**
 * Builds an XMLLoader that parses with the given XMLReader.
 *
 * @param r the XMLReader the returned loader's `reader` member is bound to
 * @return  a loader whose load* methods use `r`
 */
def withXMLReader(r: XMLReader): XMLLoader[Elem] =
  new XMLLoader[Elem] {
    override val reader: XMLReader = r
  }
78+
7579
/**
7680
* Saves a node to a file with given filename using given encoding
7781
* optionally with xmldecl and doctype declaration.

shared/src/main/scala/scala/xml/factory/XMLLoader.scala

+46-28
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ package scala
1414
package xml
1515
package factory
1616

17-
import org.xml.sax.SAXNotRecognizedException
17+
import org.xml.sax.{SAXNotRecognizedException, XMLReader}
1818
import javax.xml.parsers.SAXParserFactory
1919
import parsing.{FactoryAdapter, NoBindingFactoryAdapter}
2020
import java.io.{File, FileDescriptor, InputStream, Reader}
@@ -46,59 +46,77 @@ trait XMLLoader[T <: Node] {
4646
/* Override this to use a different SAXParser. */
4747
def parser: SAXParser = parserInstance.get
4848

49+
/* Override this to use a different XMLReader. */
50+
def reader: XMLReader = parser.getXMLReader
51+
4952
/**
5053
* Loads XML from the given InputSource, using the supplied parser.
5154
* The methods available in scala.xml.XML use the XML parser in the JDK.
5255
*/
53-
def loadXML(source: InputSource, parser: SAXParser): T = {
54-
val result: FactoryAdapter = parse(source, parser)
56+
def loadXML(source: InputSource, parser: SAXParser): T = loadXML(source, parser.getXMLReader)
57+
58+
def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(source, parser.getXMLReader)
59+
60+
private def loadXML(source: InputSource, reader: XMLReader): T = {
61+
val result: FactoryAdapter = parse(source, reader)
5562
result.rootElem.asInstanceOf[T]
5663
}
57-
58-
def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = {
59-
val result: FactoryAdapter = parse(source, parser)
64+
65+
private def loadXMLNodes(source: InputSource, reader: XMLReader): Seq[Node] = {
66+
val result: FactoryAdapter = parse(source, reader)
6067
result.prolog ++ (result.rootElem :: result.epilogue)
6168
}
6269

63-
private def parse(source: InputSource, parser: SAXParser): FactoryAdapter = {
70+
private def parse(source: InputSource, reader: XMLReader): FactoryAdapter = {
71+
if (source == null) throw new IllegalArgumentException("InputSource cannot be null")
72+
6473
val result: FactoryAdapter = adapter
6574

75+
reader.setContentHandler(result)
76+
reader.setDTDHandler(result)
77+
/* Do not overwrite pre-configured EntityResolver. */
78+
if (reader.getEntityResolver == null) reader.setEntityResolver(result)
79+
/* Do not overwrite pre-configured ErrorHandler. */
80+
if (reader.getErrorHandler == null) reader.setErrorHandler(result)
81+
6682
try {
67-
parser.setProperty("http://xml.org/sax/properties/lexical-handler", result)
83+
reader.setProperty("http://xml.org/sax/properties/lexical-handler", result)
6884
} catch {
6985
case _: SAXNotRecognizedException =>
7086
}
7187

7288
result.scopeStack = TopScope :: result.scopeStack
73-
parser.parse(source, result)
89+
reader.parse(source)
7490
result.scopeStack = result.scopeStack.tail
7591

7692
result
7793
}
7894

95+
/** Loads XML from the given InputSource. */
96+
def load(source: InputSource): T = loadXML(source, reader)
97+
7998
/** Loads XML from the given file, file descriptor, or filename. */
80-
def loadFile(file: File): T = loadXML(fromFile(file), parser)
81-
def loadFile(fd: FileDescriptor): T = loadXML(fromFile(fd), parser)
82-
def loadFile(name: String): T = loadXML(fromFile(name), parser)
99+
def loadFile(file: File): T = load(fromFile(file))
100+
def loadFile(fd: FileDescriptor): T = load(fromFile(fd))
101+
def loadFile(name: String): T = load(fromFile(name))
83102

84-
/** loads XML from given InputStream, Reader, sysID, InputSource, or URL. */
85-
def load(is: InputStream): T = loadXML(fromInputStream(is), parser)
86-
def load(reader: Reader): T = loadXML(fromReader(reader), parser)
87-
def load(sysID: String): T = loadXML(fromSysId(sysID), parser)
88-
def load(source: InputSource): T = loadXML(source, parser)
89-
def load(url: URL): T = loadXML(fromInputStream(url.openStream()), parser)
103+
/** Loads XML from the given InputStream, Reader, sysID, or URL. */
104+
def load(is: InputStream): T = load(fromInputStream(is))
105+
def load(reader: Reader): T = load(fromReader(reader))
106+
def load(sysID: String): T = load(fromSysId(sysID))
107+
def load(url: URL): T = load(fromInputStream(url.openStream()))
90108

91109
/** Loads XML from the given String. */
92-
def loadString(string: String): T = loadXML(fromString(string), parser)
110+
def loadString(string: String): T = load(fromString(string))
93111

94112
/** Load XML nodes, including comments and processing instructions that precede and follow the root element. */
95-
def loadFileNodes(file: File): Seq[Node] = loadXMLNodes(fromFile(file), parser)
96-
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadXMLNodes(fromFile(fd), parser)
97-
def loadFileNodes(name: String): Seq[Node] = loadXMLNodes(fromFile(name), parser)
98-
def loadNodes(is: InputStream): Seq[Node] = loadXMLNodes(fromInputStream(is), parser)
99-
def loadNodes(reader: Reader): Seq[Node] = loadXMLNodes(fromReader(reader), parser)
100-
def loadNodes(sysID: String): Seq[Node] = loadXMLNodes(fromSysId(sysID), parser)
101-
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, parser)
102-
def loadNodes(url: URL): Seq[Node] = loadXMLNodes(fromInputStream(url.openStream()), parser)
103-
def loadStringNodes(string: String): Seq[Node] = loadXMLNodes(fromString(string), parser)
113+
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, reader)
114+
def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file))
115+
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadNodes(fromFile(fd))
116+
def loadFileNodes(name: String): Seq[Node] = loadNodes(fromFile(name))
117+
def loadNodes(is: InputStream): Seq[Node] = loadNodes(fromInputStream(is))
118+
def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader))
119+
def loadNodes(sysID: String): Seq[Node] = loadNodes(fromSysId(sysID))
120+
def loadNodes(url: URL): Seq[Node] = loadNodes(fromInputStream(url.openStream()))
121+
def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string))
104122
}

shared/src/main/scala/scala/xml/package.scala

+1
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,6 @@ package object xml {
8080
type SAXParseException = org.xml.sax.SAXParseException
8181
type EntityResolver = org.xml.sax.EntityResolver
8282
type InputSource = org.xml.sax.InputSource
83+
type XMLReader = org.xml.sax.XMLReader
8384
type SAXParser = javax.xml.parsers.SAXParser
8485
}

shared/src/main/scala/scala/xml/parsing/MarkupParser.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ trait MarkupParser extends MarkupParserCommon with TokenTests {
9898
var extIndex = -1
9999

100100
/** holds temporary values of pos */
101-
// Note: this is clearly an override, but if marked as such it causes a "...cannot override a mutable variable"
102-
// error with Scala 3; does it work with Scala 3 if not explicitly marked as an override remains to be seen...
101+
// Note: if marked as an override, this causes a "...cannot override a mutable variable" error with Scala 3;
102+
// SethTisue noted on Oct 14, 2021 that lampepfl/dotty#13744 should fix it - and it probably did,
103+
// but Scala XML still builds against a Scala 3 version that has this bug, so this still cannot be marked as an override :(
103104
var tmppos: Int = _
104105

105106
/** holds the next character */

0 commit comments

Comments
 (0)