diff --git a/xml/src/main/scala/akka/stream/alpakka/xml/impl/StreamingXmlParser.scala b/xml/src/main/scala/akka/stream/alpakka/xml/impl/StreamingXmlParser.scala index faf0ad104e..cac45b0a06 100644 --- a/xml/src/main/scala/akka/stream/alpakka/xml/impl/StreamingXmlParser.scala +++ b/xml/src/main/scala/akka/stream/alpakka/xml/impl/StreamingXmlParser.scala @@ -17,21 +17,45 @@ import scala.annotation.tailrec private[xml] object StreamingXmlParser { lazy val withStreamingFinishedException = new IllegalStateException("Stream finished before event was fully parsed.") + + sealed trait ContextHandler[A, B, Ctx] { + def getByteString(a: A): ByteString + def getContext(a: A): Ctx + def buildOutput(pe: ParseEvent, ctx: Ctx): B + } + + object ContextHandler { + final val uncontextual: ContextHandler[ByteString, ParseEvent, Unit] = + new ContextHandler[ByteString, ParseEvent, Unit] { + def getByteString(a: ByteString): ByteString = a + def getContext(a: ByteString): Unit = () + def buildOutput(pe: ParseEvent, ctx: Unit): ParseEvent = pe + } + + final def contextual[Ctx]: ContextHandler[(ByteString, Ctx), (ParseEvent, Ctx), Ctx] = + new ContextHandler[(ByteString, Ctx), (ParseEvent, Ctx), Ctx] { + def getByteString(a: (ByteString, Ctx)): ByteString = a._1 + def getContext(a: (ByteString, Ctx)): Ctx = a._2 + def buildOutput(pe: ParseEvent, ctx: Ctx): (ParseEvent, Ctx) = (pe, ctx) + } + } } /** * INTERNAL API */ -@InternalApi private[xml] class StreamingXmlParser(ignoreInvalidChars: Boolean, - configureFactory: AsyncXMLInputFactory => Unit) - extends GraphStage[FlowShape[ByteString, ParseEvent]] { - val in: Inlet[ByteString] = Inlet("XMLParser.in") - val out: Outlet[ParseEvent] = Outlet("XMLParser.out") - override val shape: FlowShape[ByteString, ParseEvent] = FlowShape(in, out) +@InternalApi private[xml] class StreamingXmlParser[A, B, Ctx](ignoreInvalidChars: Boolean, + configureFactory: AsyncXMLInputFactory => Unit, + transform: StreamingXmlParser.ContextHandler[A, B, Ctx]) + extends GraphStage[FlowShape[A, B]] { + val in: Inlet[A] = Inlet("XMLParser.in") + val out: Outlet[B] = Outlet("XMLParser.out") + override val shape: FlowShape[A, B] = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { private var started: Boolean = false + private var context: Ctx = _ import javax.xml.stream.XMLStreamConstants @@ -45,7 +69,10 @@ private[xml] object StreamingXmlParser { setHandlers(in, out, this) override def onPush(): Unit = { - val array = grab(in).toArray + val a = grab(in) + val bs = transform.getByteString(a) + context = transform.getContext(a) + val array = bs.toArray parser.getInputFeeder.feedInput(array, 0, array.length) advanceParser() } @@ -67,10 +94,10 @@ private[xml] object StreamingXmlParser { case XMLStreamConstants.START_DOCUMENT => started = true - push(out, StartDocument) + push(out, transform.buildOutput(StartDocument, context)) case XMLStreamConstants.END_DOCUMENT => - push(out, EndDocument) + push(out, transform.buildOutput(EndDocument, context)) completeStage() case XMLStreamConstants.START_ELEMENT => @@ -91,27 +118,30 @@ private[xml] object StreamingXmlParser { val optNs = optPrefix.flatMap(prefix => Option(parser.getNamespaceURI(prefix))) push( out, - StartElement(parser.getLocalName, - attributes, - optPrefix.filterNot(_ == ""), - optNs.filterNot(_ == ""), - namespaceCtx = namespaces) + transform.buildOutput(StartElement(parser.getLocalName, + attributes, + optPrefix.filterNot(_ == ""), + optNs.filterNot(_ == ""), + namespaceCtx = namespaces), + context) ) case XMLStreamConstants.END_ELEMENT => - push(out, EndElement(parser.getLocalName)) + push(out, transform.buildOutput(EndElement(parser.getLocalName), context)) case XMLStreamConstants.CHARACTERS => - push(out, Characters(parser.getText)) + push(out, transform.buildOutput(Characters(parser.getText), context)) case XMLStreamConstants.PROCESSING_INSTRUCTION => - push(out, ProcessingInstruction(Option(parser.getPITarget), Option(parser.getPIData))) + push(out, + transform.buildOutput(ProcessingInstruction(Option(parser.getPITarget), Option(parser.getPIData)), + context)) case XMLStreamConstants.COMMENT => - push(out, Comment(parser.getText)) + push(out, transform.buildOutput(Comment(parser.getText), context)) case XMLStreamConstants.CDATA => - push(out, CData(parser.getText)) + push(out, transform.buildOutput(CData(parser.getText), context)) // Do not support DTD, SPACE, NAMESPACE, NOTATION_DECLARATION, ENTITY_DECLARATION, PROCESSING_INSTRUCTION // ATTRIBUTE is handled in START_ELEMENT implicitly diff --git a/xml/src/main/scala/akka/stream/alpakka/xml/javadsl/XmlParsing.scala b/xml/src/main/scala/akka/stream/alpakka/xml/javadsl/XmlParsing.scala index 85826fc6ec..ab9b0bd517 100644 --- a/xml/src/main/scala/akka/stream/alpakka/xml/javadsl/XmlParsing.scala +++ b/xml/src/main/scala/akka/stream/alpakka/xml/javadsl/XmlParsing.scala @@ -23,12 +23,28 @@ object XmlParsing { def parser(): akka.stream.javadsl.Flow[ByteString, ParseEvent, NotUsed] = xml.scaladsl.XmlParsing.parser.asJava + /** + * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX while keeping + * a context attached. + */ + def parserWithContext[Ctx](): akka.stream.javadsl.FlowWithContext[ByteString, Ctx, ParseEvent, Ctx, NotUsed] = + xml.scaladsl.XmlParsing.parserWithContext().asJava + /** * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX. */ def parser(ignoreInvalidChars: Boolean): akka.stream.javadsl.Flow[ByteString, ParseEvent, NotUsed] = xml.scaladsl.XmlParsing.parser(ignoreInvalidChars).asJava + /** + * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX while keeping + * a context attached. + */ + def parserWithContext[Ctx]( + ignoreInvalidChars: Boolean + ): akka.stream.javadsl.FlowWithContext[ByteString, Ctx, ParseEvent, Ctx, NotUsed] = + xml.scaladsl.XmlParsing.parserWithContext(ignoreInvalidChars).asJava + /** * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX. */ @@ -46,6 +62,16 @@ object XmlParsing { ): akka.stream.javadsl.Flow[ByteString, ParseEvent, NotUsed] = xml.scaladsl.XmlParsing.parser(ignoreInvalidChars, configureFactory.accept(_)).asJava + /** + * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX while keeping + * a context attached. + */ + def parserWithContext[Ctx]( + ignoreInvalidChars: Boolean, + configureFactory: Consumer[AsyncXMLInputFactory] + ): akka.stream.javadsl.FlowWithContext[ByteString, Ctx, ParseEvent, Ctx, NotUsed] = + xml.scaladsl.XmlParsing.parserWithContext(ignoreInvalidChars, configureFactory.accept(_)).asJava + /** * A Flow that transforms a stream of XML ParseEvents. This stage coalesces consequitive CData and Characters * events into a single Characters event or fails if the buffered string is larger than the maximum defined. diff --git a/xml/src/main/scala/akka/stream/alpakka/xml/scaladsl/XmlParsing.scala b/xml/src/main/scala/akka/stream/alpakka/xml/scaladsl/XmlParsing.scala index 7f0ee8c11c..9c1f6fb1a8 100644 --- a/xml/src/main/scala/akka/stream/alpakka/xml/scaladsl/XmlParsing.scala +++ b/xml/src/main/scala/akka/stream/alpakka/xml/scaladsl/XmlParsing.scala @@ -7,7 +7,7 @@ package akka.stream.alpakka.xml.scaladsl import akka.NotUsed import akka.stream.alpakka.xml.ParseEvent import akka.stream.alpakka.xml.impl -import akka.stream.scaladsl.Flow +import akka.stream.scaladsl.{Flow, FlowWithContext} import akka.util.ByteString import com.fasterxml.aalto.AsyncXMLInputFactory import org.w3c.dom.Element @@ -40,7 +40,31 @@ object XmlParsing { */ def parser(ignoreInvalidChars: Boolean = false, configureFactory: AsyncXMLInputFactory => Unit = configureDefault): Flow[ByteString, ParseEvent, NotUsed] = - Flow.fromGraph(new impl.StreamingXmlParser(ignoreInvalidChars, configureFactory)) + Flow[ByteString].via( + Flow.fromGraph( + new impl.StreamingXmlParser[ByteString, ParseEvent, Unit](ignoreInvalidChars, + configureFactory, + impl.StreamingXmlParser.ContextHandler.uncontextual) + ) + ) + + /** + * Parser Flow that takes a stream of ByteStrings and parses them to XML events similar to SAX while keeping + * a context attached. + */ + def parserWithContext[Ctx]( + ignoreInvalidChars: Boolean = false, + configureFactory: AsyncXMLInputFactory => Unit = configureDefault + ): FlowWithContext[ByteString, Ctx, ParseEvent, Ctx, NotUsed] = + FlowWithContext.fromTuples( + Flow.fromGraph( + new impl.StreamingXmlParser[(ByteString, Ctx), (ParseEvent, Ctx), Ctx]( + ignoreInvalidChars, + configureFactory, + impl.StreamingXmlParser.ContextHandler.contextual + ) + ) + ) /** * A Flow that transforms a stream of XML ParseEvents. This stage coalesces consecutive CData and Characters diff --git a/xml/src/test/scala/docs/scaladsl/XmlProcessingSpec.scala b/xml/src/test/scala/docs/scaladsl/XmlProcessingSpec.scala index d1dfbac428..20ff795021 100644 --- a/xml/src/test/scala/docs/scaladsl/XmlProcessingSpec.scala +++ b/xml/src/test/scala/docs/scaladsl/XmlProcessingSpec.scala @@ -8,7 +8,7 @@ import akka.actor.ActorSystem import akka.stream.alpakka.testkit.scaladsl.LogCapturing import akka.stream.alpakka.xml._ import akka.stream.alpakka.xml.scaladsl.XmlParsing -import akka.stream.scaladsl.{Flow, Keep, Sink, Source} +import akka.stream.scaladsl.{Flow, Framing, Keep, Sink, Source} import akka.util.ByteString import org.scalatest.concurrent.ScalaFutures import org.scalatest.BeforeAndAfterAll @@ -356,6 +356,45 @@ class XmlProcessingSpec extends AnyWordSpec with Matchers with ScalaFutures with configWasCalled shouldBe true } + "parse XML and attach line numbers as context" in { + val doc = """| + | + | elem1 + | + | + | elem2 + | + |""".stripMargin + val resultFuture = Source + .single(ByteString(doc)) + .via( + Framing.delimiter(delimiter = ByteString(System.lineSeparator), + maximumFrameLength = 65536, + allowTruncation = true) + ) + .zipWithIndex + .runWith(XmlParsing.parserWithContext[Long]().asFlow.toMat(Sink.seq)(Keep.right)) + + resultFuture.futureValue should ===( + List( + (StartDocument, 0L), + (StartElement("doc"), 0L), + (Characters(" "), 1L), + (StartElement("elem"), 1L), + (Characters(" elem1"), 2L), + (Characters(" "), 3L), + (EndElement("elem"), 3L), + (Characters(" "), 4L), + (StartElement("elem"), 4L), + (Characters(" elem2"), 5L), + (Characters(" "), 6L), + (EndElement("elem"), 6L), + (EndElement("doc"), 7L), + (EndDocument, 7L) + ) + ) + } + } override protected def afterAll(): Unit = system.terminate()