diff --git a/build.sbt b/build.sbt index bdd44a2e..e063fc0d 100644 --- a/build.sbt +++ b/build.sbt @@ -269,6 +269,19 @@ lazy val csv = Build ) .dependsOn(openapi) +lazy val xml = Build + .defineProject("xml") + .settings( + libraryDependencies ++= scalaXml + ) + .dependsOn(openapi) + +lazy val soap = Build + .defineProject("soap") + .settings( + ) + .dependsOn(xml) + lazy val root = (project in file(".")) .settings( name := "chopsticks", @@ -293,6 +306,8 @@ lazy val root = (project in file(".")) metric, openapi, csv, + xml, + soap, alertmanager, prometheus, zioGrpcCommon, diff --git a/chopsticks-csv/src/main/scala/dev/chopsticks/csv/CsvEncoder.scala b/chopsticks-csv/src/main/scala/dev/chopsticks/csv/CsvEncoder.scala index bf232c84..8cc78889 100644 --- a/chopsticks-csv/src/main/scala/dev/chopsticks/csv/CsvEncoder.scala +++ b/chopsticks-csv/src/main/scala/dev/chopsticks/csv/CsvEncoder.scala @@ -4,7 +4,7 @@ import dev.chopsticks.openapi.{OpenApiParsedAnnotations, OpenApiSumTypeSerDeStra import dev.chopsticks.openapi.common.{ConverterCache, OpenApiConverterUtils} import org.apache.commons.text.StringEscapeUtils import zio.schema.{Schema, StandardType, TypeId} -import zio.{Chunk, ChunkBuilder} +import zio.Chunk import zio.schema.Schema.{Field, Primitive} import java.time.{ @@ -58,27 +58,6 @@ final case class CsvEncodingResult(headers: Chunk[String], rows: Chunk[Chunk[Str trait CsvEncoder[A] { self => - def encodeSeq(values: Iterable[A]): CsvEncodingResult = { - val encodedValues = Chunk.fromIterable(values).map(v => encode(v)) - val headers = encodedValues - .foldLeft(mutable.SortedSet.empty[String]) { case (acc, next) => - acc ++ next.keys - } - val singleRowBuilder = ChunkBuilder.make[String](headers.size) - val rows = encodedValues - .foldLeft(ChunkBuilder.make[Chunk[String]](values.size)) { case (acc, next) => - singleRowBuilder.clear() - val row = headers - .foldLeft(singleRowBuilder) { case (rowBuilder, header) => - rowBuilder += next.getOrElse(header, "") - } - .result() - acc += row - } - .result() - - CsvEncodingResult(Chunk.fromIterable(headers), Chunk.fromIterable(rows)) - } def encode(value: A): mutable.LinkedHashMap[String, String] = encode(value, columnName = None, mutable.LinkedHashMap.empty) @@ -464,7 +443,7 @@ object CsvEncoder { val diff = discriminator.mapping.values.toSet.diff(encodersByName.keySet) if (diff.nonEmpty) { throw new RuntimeException( - s"Cannot derive CsvEncoder for ${enumAnnotations.entityName.getOrElse("-")}, because mapping and decoders don't match. Diff=$diff." + s"Cannot derive CsvEncoder for ${id.name}, because mapping and decoders don't match. Diff=$diff." ) } new CsvEncoder[A] { diff --git a/chopsticks-openapi/src/main/scala/dev/chopsticks/openapi/common/OpenApiConverterUtils.scala b/chopsticks-openapi/src/main/scala/dev/chopsticks/openapi/common/OpenApiConverterUtils.scala index 26f0169b..664de071 100644 --- a/chopsticks-openapi/src/main/scala/dev/chopsticks/openapi/common/OpenApiConverterUtils.scala +++ b/chopsticks-openapi/src/main/scala/dev/chopsticks/openapi/common/OpenApiConverterUtils.scala @@ -47,4 +47,25 @@ object OpenApiConverterUtils { } } } + + private[chopsticks] def isSeq(schema: Schema[_]): Boolean = { + schema match { + case _: Schema.Sequence[_, _, _] => true + case _: Schema.Set[_] => true + case _: Schema.Primitive[_] => false + case o: Schema.Optional[_] => isSeq(o.schema) + case t: Schema.Transform[_, _, _] => isSeq(t.schema) + case l: Schema.Lazy[_] => isSeq(l.schema) + case _: Schema.Record[_] => false + case _: Schema.Enum[_] => false + case _: Schema.Map[_, _] => false + case _: Schema.Either[_, _] => false + case _: Schema.Tuple2[_, _] => false + case _: Schema.Fail[_] => false + case _: Schema.Fallback[_, _] => + throw new IllegalArgumentException("Fallback schema is not supported") + case _: Schema.Dynamic => + throw new IllegalArgumentException("Dynamic schema is not supported") + } + } } diff --git a/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/Wsdl.scala b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/Wsdl.scala new file mode 100644 index 00000000..f336f5a4 --- /dev/null +++ b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/Wsdl.scala @@ -0,0 +1,177 @@ +package dev.chopsticks.soap.wsdl + +import dev.chopsticks.soap.XsdSchema +import dev.chopsticks.soap.wsdl.WsdlParsingError.{SoapEnvelopeParsing, XmlParsing} +import dev.chopsticks.xml.{XmlDecoder, XmlDecoderError, XmlEncoder} +import zio.Chunk + +import scala.collection.immutable.{ListMap, ListSet} +import scala.util.{Failure, Success, Try} +import scala.xml.{Elem, PrefixedAttribute} + +sealed trait WsdlParsingError extends Product with Serializable +object WsdlParsingError { + final case class UnknownAction(received: String, supported: Iterable[String]) extends WsdlParsingError + final case class XmlParsing(message: String, internalException: Throwable) extends WsdlParsingError + final case class SoapEnvelopeParsing(message: String) extends WsdlParsingError + final case class XmlDecoder(errors: List[XmlDecoderError]) extends WsdlParsingError +} + +// follows subset of the WSDL 1.1 spec +final case class Wsdl( + definitions: WsdlDefinitions, + portTypeName: String, + bindingName: String, + operations: ListSet[WsdlOperation] +) { + private lazy val operationsByName = operations.iterator.map(op => op.name -> op).to(ListMap) + + def addOperation(operation: WsdlOperation): Wsdl = { + copy(operations = this.operations + operation) + } + + def parseBodyPart1( + soapAction: String, + body: String + ): Either[WsdlParsingError, (WsdlOperation, Any)] = { + operationsByName.get(soapAction) match { + case None => Left(WsdlParsingError.UnknownAction(soapAction, operationsByName.keys)) + case Some(operation) => + Try(xml.XML.loadString(body)) match { + case Failure(e) => Left(XmlParsing(s"Provided XML is not valid.", e)) + case Success(xml) => + xml.child.collectFirst { case elem: Elem if elem.label == "Body" => elem } match { + case None => Left(SoapEnvelopeParsing("Body element not found in the provided XML.")) + case Some(bodyElem) => + bodyElem.child.collectFirst { case elem: Elem if elem.label == operation.name => elem } match { + case None => + Left(SoapEnvelopeParsing(s"${operation.name} element not found within the soapenv:Body.")) + case Some(operationElem) => + operation.input.parts.headOption match { + case None => Left(SoapEnvelopeParsing(s"No input parts found for operation ${operation.name}.")) + case Some(wsdlPart) => + operationElem.child.collectFirst { case e: Elem if e.label == wsdlPart.name => e } match { + case None => + Left(SoapEnvelopeParsing( + s"${wsdlPart.name} not found for operation ${operation.name} in the received XML." + )) + case Some(partElem) => + wsdlPart.xmlDecoder.parse(partElem.child) match { + case Left(errors) => Left(WsdlParsingError.XmlDecoder(errors)) + case Right(value) => Right((operation, value)) + } + } + } + } + } + } + } + } + + def serializeResponsePart1[A: XsdSchema](operation: WsdlOperation, response: A): Elem = { + if (operation.output.parts.size != 1) { + throw new RuntimeException( + s"Only single part output is supported. Got ${operation.output.parts.size} parts instead." + ) + } + if (operation.output.parts.head.xsdSchema != implicitly[XsdSchema[A]]) { + throw new RuntimeException("XsdSchema[A] must match the schema of the part.") + } + val part = operation.output.parts.head.asInstanceOf[WsdlMessagePart[A]] + Elem( + "soapenv", + "Envelope", + new PrefixedAttribute("xmlns", "soapenv", "http://schemas.xmlsoap.org/soap/envelope/", scala.xml.Null), + scala.xml.TopScope, + minimizeEmpty = true, + Elem( + "soapenv", + "Header", + scala.xml.Null, + scala.xml.TopScope, + minimizeEmpty = true + ), + Elem( + "soapenv", + "Body", + scala.xml.Null, + scala.xml.TopScope, + minimizeEmpty = true, + Elem( + definitions.custom.key, + part.xsdSchema.name, + new PrefixedAttribute( + "xmlns", + definitions.custom.key, + definitions.targetNamespace, + new PrefixedAttribute("soapenv", "encodingStyle", "http://schemas.xmlsoap.org/soap/encoding/", xml.Null) + ), + scala.xml.TopScope, + minimizeEmpty = true, + Elem( + null, + part.name, + xml.Null, + xml.TopScope, + minimizeEmpty = true, + part.xmlEncoder.encode(response): _* + ) + ) + ) + ) + } +} + +object Wsdl { + def withDefinitions( + targetNamespace: String, + portTypeName: String, + bindingName: String, + definition: WsdlDefinition + ): Wsdl = { + Wsdl( + definitions = WsdlDefinitions( + targetNamespace = targetNamespace, + custom = definition + ), + portTypeName = portTypeName, + bindingName = bindingName, + operations = ListSet.empty + ) + } +} + +final case class WsdlDefinitions( + targetNamespace: String, + xmlns: WsdlDefinitionAddress = WsdlDefinition.wsdl.address, + wsdl: WsdlDefinitionAddress = WsdlDefinition.wsdl.address, + wsdlsoap: WsdlDefinitionAddress = WsdlDefinition.wsdlsoap.address, + xsd: WsdlDefinitionAddress = WsdlDefinition.xsd.address, + custom: WsdlDefinition +) { + lazy val defs = Chunk[WsdlDefinition]( + WsdlDefinition(WsdlDefinition.wsdl.key, wsdl), + WsdlDefinition(WsdlDefinition.wsdlsoap.key, wsdlsoap), + WsdlDefinition(WsdlDefinition.xsd.key, xsd), + custom + ) +} + +final case class WsdlDefinition(key: String, address: WsdlDefinitionAddress) +object WsdlDefinition { + val wsdl = WsdlDefinition("wsdl", WsdlDefinitionAddress("http://schemas.xmlsoap.org/wsdl/")) + val wsdlsoap = WsdlDefinition("wsdlsoap", WsdlDefinitionAddress("http://schemas.xmlsoap.org/wsdl/soap/")) + val xsd = WsdlDefinition("xsd", WsdlDefinitionAddress("http://www.w3.org/2001/XMLSchema")) +} + +final case class WsdlDefinitionAddress(value: String) extends AnyVal + +final case class WsdlMessage(name: String, parts: ListSet[WsdlMessagePart[_]]) + +final case class WsdlMessagePart[A](name: String)(implicit + val xmlEncoder: XmlEncoder[A], + val xmlDecoder: XmlDecoder[A], + val xsdSchema: XsdSchema[A] +) + +final case class WsdlOperation(name: String, input: WsdlMessage, output: WsdlMessage) diff --git a/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/WsdlSchemaPrinter.scala b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/WsdlSchemaPrinter.scala new file mode 100644 index 00000000..37285ab4 --- /dev/null +++ b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/wsdl/WsdlSchemaPrinter.scala @@ -0,0 +1,228 @@ +package dev.chopsticks.soap.wsdl + +import dev.chopsticks.soap.XsdType +import dev.chopsticks.xml.XmlPrettyPrinter + +import scala.xml.{Elem, MetaData, Node, NodeSeq, PrefixedAttribute, TopScope, UnprefixedAttribute} + +object WsdlSchemaPrinter { + + def printSchema(wsdl: Wsdl): String = { + val xmlDoc = Elem( + "wsdl", + "definitions", + new UnprefixedAttribute( + "targetNamespace", + wsdl.definitions.targetNamespace, + new UnprefixedAttribute( + "xmlns", + wsdl.definitions.xmlns.value, + wsdl.definitions.defs.foldRight(scala.xml.Null: MetaData) { case (definition, nextAttr) => + new PrefixedAttribute("xmlns", definition.key, definition.address.value, nextAttr) + } + ) + ), + TopScope, + minimizeEmpty = true, + NodeSeq + .newBuilder + .addOne(wsdlTypes(wsdl)) + .addAll(wsdlMessages(wsdl)) + .addOne(portType(wsdl)) + .addOne(bindings(wsdl)) + .result(): _* + ) + s""" + |${renderAsXml(xmlDoc)}""".stripMargin + } + + private def wsdlTypes(wsdl: Wsdl): Elem = { + val messageTypes = wsdl.operations.iterator + .flatMap(op => op.input.parts.iterator.concat(op.output.parts.iterator)) + .map(_.xsdSchema) + val types = + messageTypes.foldLeft(scala.collection.mutable.LinkedHashMap.empty[(XsdType, String), Elem]) { (acc, schema) => + schema.collectXsdTypes(wsdl.definitions.custom.key, acc) + } + val sortedKeys = types.keys.toSeq.sortBy { case (t, n) => (t.orderPriority, n) } + val elements = sortedKeys.map(k => types(k)) + Elem( + "wsdl", + "types", + xml.Null, + TopScope, + minimizeEmpty = true, + Elem( + "xsd", + "schema", + new UnprefixedAttribute("targetNamespace", wsdl.definitions.targetNamespace, scala.xml.Null), + TopScope, + minimizeEmpty = true, + elements: _* + ) + ) + } + + private def wsdlMessages(wsdl: Wsdl): NodeSeq = { + wsdl + .operations + .iterator + .flatMap { op => + Iterator(op.input, op.output) + } + .map { message => + Elem( + "wsdl", + "message", + new UnprefixedAttribute("name", message.name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + message.parts.toList.map { part => + Elem( + "wsdl", + "part", + new UnprefixedAttribute( + "name", + part.name, + new UnprefixedAttribute( + "type", + part.xsdSchema.namespacedName(wsdl.definitions.custom.key), + scala.xml.Null + ) + ), + TopScope, + minimizeEmpty = true + ) + }: _* + ) + } + .foldLeft(NodeSeq.newBuilder)((acc, elem) => acc += elem) + .result() + } + + private def portType(wsdl: Wsdl): Elem = { + val operations = wsdl.operations.iterator + .map { op => + Elem( + "wsdl", + "operation", + new UnprefixedAttribute("name", op.name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + Elem( + "wsdl", + "input", + new UnprefixedAttribute("message", s"${wsdl.definitions.custom.key}:${op.input.name}", scala.xml.Null), + TopScope, + minimizeEmpty = true + ), + Elem( + "wsdl", + "output", + new UnprefixedAttribute("message", s"${wsdl.definitions.custom.key}:${op.output.name}", scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + } + .toVector + Elem( + "wsdl", + "portType", + new UnprefixedAttribute("name", wsdl.portTypeName, scala.xml.Null), + TopScope, + minimizeEmpty = true, + operations: _* + ) + } + + private def bindings(wsdl: Wsdl) = { + Elem( + "wsdl", + "binding", + new UnprefixedAttribute( + "name", + wsdl.bindingName, + new UnprefixedAttribute("type", s"${wsdl.definitions.custom.key}:${wsdl.portTypeName}", scala.xml.Null) + ), + TopScope, + minimizeEmpty = true, + Iterator + .single { + Elem( + "wsdlsoap", + "binding", + new UnprefixedAttribute( + "style", + "rpc", + new UnprefixedAttribute("transport", "http://schemas.xmlsoap.org/soap/http", scala.xml.Null) + ), + TopScope, + minimizeEmpty = true + ) + } + .concat { + wsdl.operations.toVector.map { op => + bindingOperation(wsdl, op) + } + } + .toVector: _* + ) + } + + private def bindingOperation(wsdl: Wsdl, operation: WsdlOperation): Elem = { + Elem( + "wsdl", + "operation", + new UnprefixedAttribute("name", operation.name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + Elem( + "wsdlsoap", + "operation", + new UnprefixedAttribute("soapAction", s"${operation.name}", scala.xml.Null), + TopScope, + minimizeEmpty = true + ), + Elem( + "wsdl", + "input", + scala.xml.Null, + TopScope, + minimizeEmpty = true, + wsdlSoapBody(wsdl) + ), + Elem( + "wsdl", + "output", + scala.xml.Null, + TopScope, + minimizeEmpty = true, + wsdlSoapBody(wsdl) + ) + ) + } + + private def wsdlSoapBody(wsdl: Wsdl): Elem = { + Elem( + "wsdlsoap", + "body", + new UnprefixedAttribute( + "encodingStyle", + "http://schemas.xmlsoap.org/soap/encoding/", + new UnprefixedAttribute( + "namespace", + wsdl.definitions.targetNamespace, + new UnprefixedAttribute("use", "encoded", scala.xml.Null) + ) + ), + TopScope, + minimizeEmpty = true + ) + } + + private def renderAsXml(value: Node): String = { + new XmlPrettyPrinter(120, 2, minimizeEmpty = true).format(value) + } + +} diff --git a/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdAnnotations.scala b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdAnnotations.scala new file mode 100644 index 00000000..8406c7fc --- /dev/null +++ b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdAnnotations.scala @@ -0,0 +1,23 @@ +package dev.chopsticks.soap.xsd + +import dev.chopsticks.xml.XmlAnnotations +import zio.Chunk + +final private[chopsticks] case class XsdAnnotations[A]( + xmlAnnotations: XmlAnnotations[A], + xsdSchemaName: Option[String] = None +) + +object XsdAnnotations { + final case class xsdSchemaName(name: String) extends scala.annotation.StaticAnnotation + + private[chopsticks] def extractAnnotations[A](annotations: Chunk[Any]): XsdAnnotations[A] = { + val xmlAnnotations = XmlAnnotations.extractAnnotations[A](annotations) + annotations.foldLeft(XsdAnnotations[A](xmlAnnotations)) { case (typed, annotation) => + annotation match { + case a: xsdSchemaName => typed.copy(xsdSchemaName = Some(a.name)) + case _ => typed + } + } + } +} diff --git a/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchema.scala b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchema.scala new file mode 100644 index 00000000..5e0c0813 --- /dev/null +++ b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchema.scala @@ -0,0 +1,643 @@ +package dev.chopsticks.soap + +import dev.chopsticks.openapi.common.{ConverterCache, OpenApiConverterUtils} +import dev.chopsticks.openapi.OpenApiSumTypeSerDeStrategy +import dev.chopsticks.soap.xsd.XsdAnnotations +import dev.chopsticks.soap.xsd.XsdAnnotations.extractAnnotations +import sttp.tapir.Validator +import zio.schema.{Schema, StandardType, TypeId} +import zio.Chunk + +import scala.collection.immutable.ListMap +import scala.collection.mutable +import scala.xml.{Elem, NodeSeq, TopScope, UnprefixedAttribute} + +sealed trait XsdSimpleType { + def name: String +} +object XsdSimpleType { + final case object XsdBoolean extends XsdSimpleType { + override def name: String = "xsd:boolean" + } + final case object XsdInteger extends XsdSimpleType { + override def name: String = "xsd:integer" + } + final case object XsdString extends XsdSimpleType { + override def name: String = "xsd:string" + } + final case object XsdHexBinary extends XsdSimpleType { + override def name: String = "xsd:hexBinary" + } +} + +sealed trait XsdType extends Product with Serializable { + def orderPriority: Int +} +object XsdType { + final case object Simple extends XsdType { + override def orderPriority: Int = 1 + } + final case object Complex extends XsdType { + override def orderPriority: Int = 2 + } +} + +sealed trait XsdSchema[A] { + def name: String + def namespacedName(xsdNamespaceName: String): String + def xsdTypes(xsdNamespaceName: String): ListMap[(XsdType, String), Elem] = + collectXsdTypes(xsdNamespaceName, mutable.LinkedHashMap.empty).to(ListMap) + private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] +} +object XsdSchema { + def derive[A]()(implicit schema: Schema[A]): XsdSchema[A] = { + new Converter().convert(schema, None) + } + + // taken from tapir apispec-docs + private def asPrimitiveValidators[A](v: Validator[A]): Seq[Validator.Primitive[_]] = { + def toPrimitives(v: Validator[_]): Seq[Validator.Primitive[_]] = { + v match { + case Validator.Mapped(wrapped, _) => toPrimitives(wrapped) + case Validator.All(validators) => validators.flatMap(toPrimitives) + case Validator.Any(validators) => validators.flatMap(toPrimitives) + case Validator.Custom(_, _) => Nil + case bv: Validator.Primitive[_] => List(bv) + } + } + toPrimitives(v) + } + + private def collectPrimitiveConstraints[A](xs: Seq[Validator.Primitive[_]]) = { + xs + .foldLeft(NodeSeq.newBuilder) { (acc, v) => + v match { + case Validator.Min(value, exclusive) => + acc.addOne( + Elem( + "xsd", + if (exclusive) "minExclusive" else "minInclusive", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + case Validator.Max(value, exclusive) => + acc.addOne( + Elem( + "xsd", + if (exclusive) "maxExclusive" else "maxInclusive", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + case Validator.MinLength(value, _) => + acc.addOne( + Elem( + "xsd", + "minLength", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + case Validator.MaxLength(value, _) => + acc.addOne( + Elem( + "xsd", + "maxLength", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + case Validator.Pattern(value) => + acc.addOne( + Elem( + "xsd", + "pattern", + new UnprefixedAttribute("value", value, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + + case Validator.MaxSize(value) => + acc.addOne( + Elem( + "xsd", + "maxLength", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + + case Validator.MinSize(value) => + acc.addOne( + Elem( + "xsd", + "minLength", + new UnprefixedAttribute("value", value.toString, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + + case Validator.Enumeration(possibleValues, encode, _) => + val encodeValue = encode + .map(v => v.andThen(_.map(_.toString))) + .getOrElse((v: Any) => Some(v.toString)) + possibleValues.foreach { v => + encodeValue(v) match { + case None => () + case Some(encoded) => + acc.addOne( + Elem( + "xsd", + "enumeration", + new UnprefixedAttribute("value", encoded, scala.xml.Null), + TopScope, + minimizeEmpty = true + ) + ) + } + } + acc + + case Validator.Custom(_, _) => acc + } + } + .result() + } + + final case class Primitive[A](base: XsdSimpleType) extends XsdSchema[A] { + override def namespacedName(xsdNamespaceName: String) = base.name + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = acc + override def name: String = base.name + } + + final case class LazyXsdSchema[A]() extends XsdSchema[A] with ConverterCache.Lazy[XsdSchema[A]] { + def getSchema(): XsdSchema[A] = get + override def namespacedName(xsdNamespaceName: String) = get.namespacedName(xsdNamespaceName) + override def name: String = get.name + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = + get.collectXsdTypes(xsdNamespaceName, acc) + } + + final case class SimpleDerived[A]( + name: String, + base: XsdSimpleType, + validator: Validator[A] + ) extends XsdSchema[A] { + override def namespacedName(xsdNamespaceName: String): String = s"$xsdNamespaceName:$name" + + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = { + if (acc.contains((XsdType.Simple, name))) acc + else { + val primitiveValidators = asPrimitiveValidators(validator) + val constraints = collectPrimitiveConstraints(primitiveValidators) + val xmlElem = Elem( + "xsd", + "simpleType", + new UnprefixedAttribute("name", name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + Elem( + "xsd", + "restriction", + new UnprefixedAttribute("base", base.name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + constraints: _* + ) + ) + acc += ((XsdType.Simple, name) -> xmlElem) + } + } + } + + final case class Complex[A]( + name: String, + validator: Validator[A], + collectTypes: ( + String, + String, + mutable.LinkedHashMap[(XsdType, String), Elem] + ) => mutable.LinkedHashMap[(XsdType, String), Elem] + ) extends XsdSchema[A] { + override def namespacedName(xsdNamespaceName: String): String = s"$xsdNamespaceName:$name" + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = { + collectTypes(xsdNamespaceName, name, acc) + } + } + + final case class XsdOptional[A](underlying: XsdSchema[A]) extends XsdSchema[A] { + override def name: String = underlying.name + + override def namespacedName(xsdNamespaceName: String): String = underlying.namespacedName(xsdNamespaceName) + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = { + underlying.collectXsdTypes(xsdNamespaceName, acc) + } + } + + final case class XsdSequence[A](underlying: XsdSchema[A], nodeName: String, validator: Validator[A]) + extends XsdSchema[A] { + override def name: String = underlying.name + override def namespacedName(xsdNamespaceName: String): String = underlying.namespacedName(xsdNamespaceName) + + override private[chopsticks] def collectXsdTypes( + xsdNamespaceName: String, + acc: mutable.LinkedHashMap[(XsdType, String), Elem] + ): mutable.LinkedHashMap[(XsdType, String), Elem] = { + underlying.collectXsdTypes(xsdNamespaceName, acc) + } + } + + private val _boolSchemaType = Primitive[Boolean](XsdSimpleType.XsdBoolean) + private val _intSchemaType = Primitive[Int](XsdSimpleType.XsdInteger) + private val _stringSchemaType = Primitive[String](XsdSimpleType.XsdString) + private val _hexBinarySchemaType = Primitive[Array[Byte]](XsdSimpleType.XsdHexBinary) + + private def boolSchema[A] = _boolSchemaType.asInstanceOf[Primitive[A]] + private def intSchema[A] = _intSchemaType.asInstanceOf[Primitive[A]] + private def stringSchema[A] = _stringSchemaType.asInstanceOf[Primitive[A]] + private def hexBinarySchema[A] = _hexBinarySchemaType.asInstanceOf[Primitive[A]] + + private class Converter(cache: ConverterCache[XsdSchema] = new ConverterCache[XsdSchema]) { + private def convertUsingCache[A]( + typeId: TypeId, + annotations: XsdAnnotations[A] + )(convert: => XsdSchema[A]): XsdSchema[A] = { + cache + .convertUsingCache(typeId, annotations.xmlAnnotations.openApiAnnotations)(convert) { () => + new LazyXsdSchema[A]() + } + } + + private[chopsticks] def convert[A](schema: Schema[A], xmlSeqNodeName: Option[String]): XsdSchema[A] = { + schema match { + case Schema.Primitive(standardType, annotations) => + primitiveConverter(standardType, annotations) + + case Schema.Sequence(schemaA, _, _, annotations, _) => + val parsed = extractAnnotations[A](annotations) + val nodeName = parsed.xmlAnnotations.xmlFieldName + .orElse(xmlSeqNodeName) + .getOrElse { + throw new RuntimeException("Sequence must have xmlFieldName annotation") + } + addAnnotations[A]( + None, + XsdSequence(convert(schemaA, Some(nodeName)), nodeName, Validator.pass).asInstanceOf[XsdSchema[A]], + parsed + ) + + case Schema.Set(_, _) => + ??? + + case Schema.Transform(schema, _, _, annotations, _) => + val typedAnnotations = extractAnnotations[A](annotations) + val baseEncoder = convert(schema, typedAnnotations.xmlAnnotations.xmlFieldName.orElse(xmlSeqNodeName)) + .asInstanceOf[XsdSchema[A]] + addAnnotations(None, baseEncoder, typedAnnotations) + + case Schema.Optional(schema, annotations) => + addAnnotations[A]( + None, + baseSchema = XsdOptional(convert(schema, xmlSeqNodeName)).asInstanceOf[XsdSchema[A]], + metadata = extractAnnotations(annotations) + ) + + case l @ Schema.Lazy(_) => + convert(l.schema, xmlSeqNodeName) + + case s: Schema.Record[A] => + convertRecord[A](s) + + case s: Schema.Enum[A] => + convertEnum[A](s) + + case other => + throw new IllegalArgumentException(s"Unsupported schema type: $other") + } + } + + private def findChildTypeName(namespace: String, xsdSchema: XsdSchema[_]): String = xsdSchema match { + case o: XsdOptional[_] => findChildTypeName(namespace, o.underlying) + case l: LazyXsdSchema[_] => findChildTypeName(namespace, l.getSchema()) + case _: Primitive[_] => xsdSchema.name + case _ => s"$namespace:${xsdSchema.name}" + } + + private def convertRecord[A](r: Schema.Record[A]): XsdSchema[A] = { + + def isOptional(s: XsdSchema[_]): Boolean = s match { + case _: XsdOptional[_] => true + case xs: XsdSequence[_] => isOptional(xs.underlying) + case _ => false + } + + def minOccurs(s: XsdSchema[_]): Option[Int] = { + s match { + case _: XsdOptional[_] => Some(0) + case xs: XsdSequence[_] => + asPrimitiveValidators(xs.validator).collectFirst { + case Validator.MinSize(value) => value + } + case _ => None + } + } + + def maxOccurs(s: XsdSchema[_]): Option[Int] = s match { + case o: XsdOptional[_] => maxOccurs(o.underlying) + case xs: XsdSequence[_] => + asPrimitiveValidators(xs.validator).collectFirst { + case Validator.MaxSize(value) => value + } + case _ => None + } + + def asSeq(s: XsdSchema[_]): Option[XsdSequence[_]] = s match { + case o: XsdOptional[_] => asSeq(o.underlying) + case s: XsdSequence[_] => Some(s) + case _ => None + } + + val recordAnnotations: XsdAnnotations[A] = extractAnnotations[A](r.annotations) + convertUsingCache(r.id, recordAnnotations) { + val fieldEncoders = r.fields + .map { field => + val fieldAnnotations = extractAnnotations[Any](field.annotations) + addAnnotations[Any]( + None, + convert[Any]( + field.schema.asInstanceOf[Schema[Any]], + fieldAnnotations.xmlAnnotations.xmlFieldName.orElse(Some(field.name)) + ), + fieldAnnotations + ) + } + val fieldNames = r.fields.map { field => + extractAnnotations[Any](field.annotations).xmlAnnotations.xmlFieldName + } + + val baseEncoder = Complex[A]( + name = "", // placeholder, will be replaced by the addAnnotations + validator = Validator.pass, + collectTypes = (namespace, name, acc) => { + var result = acc + val fieldsBuilder = NodeSeq.newBuilder + var i = 0 + while (i < r.fields.length) { + val field = r.fields(i) + val encoder = fieldEncoders(i) + result = encoder.collectXsdTypes(namespace, result) + val childTypeName = findChildTypeName(namespace, encoder) + val fieldName = fieldNames(i).getOrElse(field.name) + val encoderAsSeq = asSeq(encoder) + val isSequence = encoderAsSeq.isDefined + val maxOccursTag = { + if (isSequence) { + val max = maxOccurs(encoder) + .map(_.toString) + .getOrElse("unbounded") + new UnprefixedAttribute("maxOccurs", max, scala.xml.Null) + } + else { + scala.xml.Null + } + } + + val minOccursTag = { + Option + .when(isOptional(encoder))(0) + .orElse(minOccurs(encoder)) match { + case Some(min) => new UnprefixedAttribute("minOccurs", min.toString, maxOccursTag) + case None => maxOccursTag + } + } + + val newElem = scala.xml.Elem( + "xsd", + "element", + new UnprefixedAttribute( + "name", + encoderAsSeq.map(_.nodeName).getOrElse(fieldName), + new UnprefixedAttribute( + "type", + childTypeName, + minOccursTag + ) + ), + TopScope, + minimizeEmpty = true + ) + val _ = fieldsBuilder.addOne(newElem) + i += 1 + } + val fieldElems = fieldsBuilder.result() + + val complexTypeElem = Elem( + "xsd", + "complexType", + new UnprefixedAttribute("name", name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + Elem( + "xsd", + "sequence", + scala.xml.Null, + TopScope, + minimizeEmpty = true, + fieldElems: _* + ) + ) + result += ((XsdType.Complex, name) -> complexTypeElem) + } + ) + addAnnotations(Some(r.id), baseEncoder, recordAnnotations) + } + } + + private def convertEnum[A](s: Schema.Enum[A]): XsdSchema[A] = { + val enumAnnotations = extractAnnotations[A](s.annotations) + val serDeStrategy = enumAnnotations.xmlAnnotations.openApiAnnotations.sumTypeSerDeStrategy + .getOrElse { + throw new RuntimeException( + s"Discriminator must be defined to derive an XsdSchema. Received annotations: $enumAnnotations" + ) + } + serDeStrategy match { + case OpenApiSumTypeSerDeStrategy.Discriminator(discriminator) => + val reversedDiscriminator = discriminator.mapping.map(_.swap) + if (reversedDiscriminator.size != discriminator.mapping.size) { + throw new RuntimeException( + s"Cannot derive XsdSchema for ${s.id.name}, because discriminator mapping is not unique." + ) + } + val schemasByDiscType = s.cases.iterator + .map { c => + val schema = addAnnotations( + None, + convert(c.schema, None), + extractAnnotations(c.annotations) + ) + reversedDiscriminator(c.caseName) -> (schema, c) + } + .toMap + val diff = discriminator.mapping.keySet.diff(schemasByDiscType.keySet) + if (diff.nonEmpty) { + throw new RuntimeException( + s"Cannot derive XsdSchema for ${s.id.name}, because mapping and decoders don't match. Diff=$diff." + ) + } + val schema = + Complex[A]( + name = "", + validator = Validator.pass, + collectTypes = (namespace, name, acc) => { + var result = acc + result = schemasByDiscType.iterator.foldLeft(result) { case (a, (_, (s, _))) => + s.collectXsdTypes(namespace, a) + } + val subtypesDefinition = schemasByDiscType.iterator + .foldLeft(NodeSeq.newBuilder) { case (builder, (tpe, (schema, _))) => + builder += Elem( + "xsd", + "element", + new UnprefixedAttribute( + "name", + tpe, + new UnprefixedAttribute("type", findChildTypeName(namespace, schema), scala.xml.Null) + ), + TopScope, + minimizeEmpty = true + ) + } + .result() + val xmlDefinition = Elem( + "xsd", + "complexType", + new UnprefixedAttribute("name", name, scala.xml.Null), + TopScope, + minimizeEmpty = true, + Elem( + "xsd", + "choice", + scala.xml.Null, + TopScope, + minimizeEmpty = true, + subtypesDefinition: _* + ) + ) + result += ((XsdType.Complex, name) -> xmlDefinition) + } + ) + + addAnnotations(Some(s.id), schema, enumAnnotations) + } + + } + + private def primitiveConverter[A]( + standardType: StandardType[A], + annotations: Chunk[Any] + ): XsdSchema[A] = { + val baseEncoder = standardType match { +// case StandardType.UnitType => unitEncoder + case StandardType.StringType => stringSchema[A] + case StandardType.BoolType => boolSchema[A] +// case StandardType.ByteType => byteEncoder +// case StandardType.ShortType => + case StandardType.IntType => intSchema[A] + case StandardType.BinaryType => hexBinarySchema[A] + case other => + throw new IllegalArgumentException(s"Unsupported standard type: $other") + } + addAnnotations(None, baseEncoder.asInstanceOf[XsdSchema[A]], extractAnnotations(annotations)) + } + + private def addAnnotations[A]( + typeId: Option[TypeId], + baseSchema: XsdSchema[A], + metadata: XsdAnnotations[A] + ): XsdSchema[A] = { + val name = metadata + .xsdSchemaName + .orElse(OpenApiConverterUtils.getEntityName(typeId, metadata.xmlAnnotations.openApiAnnotations)) + // todo take default into account? + baseSchema match { + case Primitive(base) => + metadata.xmlAnnotations.openApiAnnotations.validator match { + case Some(_) if name.isEmpty => + throw new IllegalArgumentException( + "xsdSchemaName or entityName annotation is required for simple types with validator" + ) + case Some(validator) => + SimpleDerived(name.get, base, validator) + case None if name.isEmpty => baseSchema + case None => SimpleDerived(name.get, base, Validator.pass) + } + case d: SimpleDerived[A] => + var result = d + result = metadata.xmlAnnotations.openApiAnnotations.validator match { + case Some(v) => + if (result.validator == Validator.pass) result.copy(validator = v) + else result.copy(validator = result.validator.and(v)) + case None => result + } + result = name match { + case Some(value) => result.copy(name = value) + case None => result + } + result + case c: Complex[A] => + var result = c + result = metadata.xmlAnnotations.openApiAnnotations.validator match { + case Some(v) => + if (result.validator == Validator.pass) result.copy(validator = v) + else result.copy(validator = result.validator.and(v)) + case None => result + } + result = name match { + case Some(value) => result.copy(name = value) + case None => result + } + result + case XsdOptional(_) => baseSchema + case xs: XsdSequence[A] => + var result = xs + result = metadata.xmlAnnotations.openApiAnnotations.validator match { + case Some(v) => + if (result.validator == Validator.pass) result.copy(validator = v) + else result.copy(validator = result.validator.and(v)) + case None => result + } + result + case LazyXsdSchema() => ??? + } + } + + } + +} diff --git a/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchemaPrinter.scala b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchemaPrinter.scala new file mode 100644 index 00000000..f126317c --- /dev/null +++ b/chopsticks-soap/src/main/scala/dev/chopsticks/soap/xsd/XsdSchemaPrinter.scala @@ -0,0 +1,29 @@ +package dev.chopsticks.soap.xsd + +import dev.chopsticks.soap.XsdSchema +import dev.chopsticks.xml.XmlPrettyPrinter + +import scala.xml.{Elem, TopScope, UnprefixedAttribute} + +object XsdSchemaPrinter { + + def printSchema[A](targetNamespaceName: String, namespaceName: String, schema: XsdSchema[A]): String = { + val types = schema.xsdTypes(namespaceName) + val sortedKeys = types.keys.toSeq.sortBy { case (t, n) => (t.orderPriority, n) } + val elements = sortedKeys.map(k => types(k)) + val xmlDoc = Elem( + "xsd", + "schema", + new UnprefixedAttribute("targetNamespace", targetNamespaceName, scala.xml.Null), + TopScope, + minimizeEmpty = true, + elements: _* + ) + renderAsXml(xmlDoc) + } + + private def renderAsXml(value: Elem): String = { + new XmlPrettyPrinter(120, 2, minimizeEmpty = true).format(value) + } + +} diff --git a/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlAnnotations.scala b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlAnnotations.scala new file mode 100644 index 00000000..d0561adc --- /dev/null +++ b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlAnnotations.scala @@ -0,0 +1,23 @@ +package dev.chopsticks.xml + +import dev.chopsticks.openapi.OpenApiParsedAnnotations +import zio.Chunk + +final private[chopsticks] case class XmlAnnotations[A]( + openApiAnnotations: OpenApiParsedAnnotations[A], + xmlFieldName: Option[String] = None +) + +object XmlAnnotations { + final case class xmlFieldName(name: String) extends scala.annotation.StaticAnnotation + + private[chopsticks] def extractAnnotations[A](annotations: Chunk[Any]): XmlAnnotations[A] = { + val openApiAnnotations = OpenApiParsedAnnotations.extractAnnotations[A](annotations) + annotations.foldLeft(XmlAnnotations[A](openApiAnnotations)) { case (typed, annotation) => + annotation match { + case a: xmlFieldName => typed.copy(xmlFieldName = Some(a.name)) + case _ => typed + } + } + } +} diff --git a/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlDecoder.scala b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlDecoder.scala new file mode 100644 index 00000000..d3653732 --- /dev/null +++ b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlDecoder.scala @@ -0,0 +1,642 @@ +package dev.chopsticks.xml + +import dev.chopsticks.openapi.common.{ConverterCache, OpenApiConverterUtils} +import dev.chopsticks.openapi.{OpenApiSumTypeSerDeStrategy, OpenApiValidation} +import dev.chopsticks.util.Hex +import dev.chopsticks.xml.XmlAnnotations.extractAnnotations +import sttp.tapir.Validator +import zio.schema.{Schema, StandardType, TypeId} +import zio.schema.Schema.Primitive +import zio.Chunk + +import java.time.{ + DayOfWeek, + Duration, + Instant, + LocalDate, + LocalDateTime, + LocalTime, + OffsetDateTime, + OffsetTime, + ZoneId, + ZoneOffset, + ZonedDateTime +} +import java.util.UUID +import scala.collection.mutable.ListBuffer +import scala.util.control.NonFatal +import scala.xml.{Elem, NodeSeq, Text} + +final case class XmlDecoderError(message: String, private[chopsticks] val reversedPath: List[String]) { +// def format: String = { +// val trimmed = message.trim +// val withDot = if (trimmed.endsWith(".")) trimmed else trimmed + "." +// withDot + (if (columnName.isDefined) s" Column(s): ${columnName.get}" else "") +// } + lazy val path: List[String] = reversedPath.reverse +} +object XmlDecoderError { +// def columnNotExists(columnName: String): XmlDecoderError = { +// XmlDecoderError(s"Required column does not exist.", columnName = Some(columnName)) +// } +// def notAllRequiredColumnsExist(columnName: Option[String]) = { +// XmlDecoderError("Not all required columns contain defined values.", columnName) +// } +// def notAllRequiredColumnsExist(columnNames: List[String]) = { +// XmlDecoderError("Not all required columns contain defined values.", Some(columnNames.mkString(", "))) +// } +// def unrecognizedDiscriminatorType(received: String, formattedKnownTypes: String, columnName: Option[String]) = { +// XmlDecoderError( +// s"Unrecognized type: ${received}. Valid object types are: $formattedKnownTypes.", +// columnName +// ) +// } +} + +trait XmlDecoder[A] { self => + def isOptional: Boolean + def parse(node: NodeSeq): Either[List[XmlDecoderError], A] = parse(node, Nil) + def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] + final def ensure(pred: A => Boolean, message: => String) = + new XmlDecoder[A] { + override val isOptional = self.isOptional +// override val isPrimitive = self.isPrimitive +// override def parseAsOption(row: Map[String, String], columnName: Option[String]) = { +// self.parseAsOption(row, columnName) match { +// case r @ Right(Some(value)) => +// if (pred(value)) r else Left(List(XmlDecoderError(message, errorColumn(options, columnName)))) +// case r @ Right(None) => r +// case l @ Left(_) => l +// } +// } + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = + self.parse(node, path) match { + case r @ Right(value) => + if (pred(value)) r + else Left(List(XmlDecoderError(message, path))) + case l @ Left(_) => l + } + } + + final def ensure(errors: A => List[String]) = new XmlDecoder[A] { + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = { + self.parse(node, path) match { + case r @ Right(value) => + errors(value) match { + case Nil => r + case errs => Left(errs.map(e => XmlDecoderError(e, path))) + } + case l @ Left(_) => l + } + } + override val isOptional = self.isOptional +// override val isPrimitive = self.isPrimitive +// override def parseAsOption( +// row: Map[String, String], +// columnName: Option[String] +// ): Either[List[XmlDecoderError], Option[A]] = { +// self.parseAsOption(row, columnName) match { +// case r @ Right(Some(value)) => +// errors(value) match { +// case Nil => r +// case errs => +// Left(errs.map(e => XmlDecoderError(e, errorColumn(options, columnName)))) +// } +// case r @ Right(None) => r +// case l @ Left(_) => l +// } + } + + final def map[B](f: A => B) = new XmlDecoder[B] { + override def isOptional = self.isOptional + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], B] = + self.parse(node, path).map(f) + } + + final def emap[B](f: A => Either[String, B]) = new XmlDecoder[B] { + override def isOptional = self.isOptional + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], B] = + self.parse(node, path) match { + case Left(errors) => Left(errors) + case Right(value) => + f(value) match { + case Left(error) => Left(List(XmlDecoderError(error, path))) + case Right(b) => Right(b) + } + } + } + +} + +object XmlDecoder { + def derive[A]()(implicit schema: Schema[A]): XmlDecoder[A] = { + new Converter().convert(schema, None) + } + + private def createPrimitive[A](f: (String, List[String]) => Either[List[XmlDecoderError], A]): XmlDecoder[A] = + new XmlDecoder[A] { + override def isOptional = false + override def parse(nodes: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = { + if (nodes.length != 1) { + Left(List(XmlDecoderError(s"Expected text value, but many nodes instead.", path))) + } + else { + val node = nodes.head + node match { + case scala.xml.Text(data) => f(data, path) + case other => + Left(List(XmlDecoderError( + s"Expected text node, but got [${other.getClass.getSimpleName}]: $other.", + path + ))) + } + } + } + } + + private def createPrimitiveFromThrowing[A](f: String => A, errorMessage: => String) = createPrimitive { + case (value, maybeColumn) => + try Right(f(value)) + catch { case NonFatal(_) => Left(List(XmlDecoderError(errorMessage, maybeColumn))) } + } + + private val unitDecoder: XmlDecoder[Unit] = createPrimitive((_, _) => Right(())) + private val boolDecoder: XmlDecoder[Boolean] = createPrimitive { (x, path) => + x.toBooleanOption match { + case Some(value) => Right(value) + case None => Left(List(XmlDecoderError(s"Cannot parse boolean (it must be either 'true' or 'false').", path))) + } + } + + private val byteDecoder: XmlDecoder[Byte] = createPrimitive { case (value, col) => + value.toByteOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse byte.", col))) + } + } + + private val shortDecoder: XmlDecoder[Short] = createPrimitive { case (value, col) => + value.toShortOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse number.", col))) + } + } + + private val intDecoder: XmlDecoder[Int] = createPrimitive { case (value, col) => + value.toIntOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse number.", col))) + } + } + + private val longDecoder: XmlDecoder[Long] = createPrimitive { case (value, col) => + value.toLongOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse number.", col))) + } + } + + private val floatDecoder: XmlDecoder[Float] = createPrimitive { case (value, col) => + value.toFloatOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse number.", col))) + } + } + + private val doubleDecoder: XmlDecoder[Double] = createPrimitive { case (value, col) => + value.toDoubleOption match { + case Some(n) => Right(n) + case None => Left(List(XmlDecoderError(s"Cannot parse number.", col))) + } + } + + private val charDecoder: XmlDecoder[Char] = createPrimitive { case (value, col) => + if (value.length == 1) Right(value.charAt(0)) + else Left(List(XmlDecoderError(s"Cannot parse char. Got string with length ${value.length} instead.", col))) + } + + private val stringDecoder: XmlDecoder[String] = createPrimitive((value, _) => Right(value)) + + private val binaryViaHexDecoder: XmlDecoder[Chunk[Byte]] = createPrimitive { (hexString, x) => + if (hexString.length % 2 != 0) { + Left(List(XmlDecoderError(s"Cannot parse hexBinary. Received a string with an odd number of characters.", x))) + } + else { + Right(Chunk.fromArray(Hex.decode(hexString))) + } + } + + private val instantDecoder: XmlDecoder[Instant] = { + createPrimitiveFromThrowing(Instant.parse, s"Cannot parse timestamp; it must be in ISO-8601 format.") + } + + private val dayOfWeekDecoder: XmlDecoder[DayOfWeek] = { + createPrimitiveFromThrowing( + DayOfWeek.valueOf, { + val expected = DayOfWeek.values().iterator.map(_.toString).mkString(", ") + s"Cannot parse day of week. Expected one of: $expected." + } + ) + } + + private val durationDecoder: XmlDecoder[Duration] = { + createPrimitiveFromThrowing( + Duration.parse, + s"Cannot parse duration; it must be in ISO-8601 format." + ) + } + + private val localTimeDecoder: XmlDecoder[LocalTime] = { + createPrimitiveFromThrowing( + LocalTime.parse, + s"""Cannot parse local time; expected text such as "10:15:00".""" + ) + } + + private val localDateTimeDecoder: XmlDecoder[LocalDateTime] = { + createPrimitiveFromThrowing( + LocalDateTime.parse, + s"""Cannot parse local date time; expected text such as "2020-12-03T10:15:30".""" + ) + } + + private val offsetTimeDecoder: XmlDecoder[OffsetTime] = { + createPrimitiveFromThrowing( + OffsetTime.parse, + s"""Cannot parse offset time; expected text such as "10:15:30+01:00".""" + ) + } + + private val offsetDateTimeDecoder: XmlDecoder[OffsetDateTime] = { + createPrimitiveFromThrowing( + OffsetDateTime.parse, + s"""Cannot parse offset date time; expected text such as "2020-12-03T10:15:30+01:00".""" + ) + } + + private val zonedDateTimeDecoder: XmlDecoder[ZonedDateTime] = { + createPrimitiveFromThrowing( + ZonedDateTime.parse, + s"""Cannot parse zoned date time; expected text such as "2020-12-03T10:15:30+01:00[Europe/Paris]".""" + ) + } + + private val bigDecimalDecoder: XmlDecoder[java.math.BigDecimal] = { + createPrimitiveFromThrowing(BigDecimal.apply(_).underlying(), s"Cannot parse BigDecimal number.") + } + + private val bigIntDecoder: XmlDecoder[java.math.BigInteger] = { + createPrimitiveFromThrowing(BigInt.apply(_).underlying(), s"Cannot parse BigInteger number.") + } + + private val localDateDecoder: XmlDecoder[LocalDate] = { + createPrimitiveFromThrowing( + LocalDate.parse, + s"Cannot parse date; it must be in ISO-8601 format (i.e. 'yyyy-MM-dd')." + ) + } + + private val uuidDecoder: XmlDecoder[UUID] = { + createPrimitiveFromThrowing(UUID.fromString, s"Cannot parse UUID.") + } + + private val zoneIdDecoder: XmlDecoder[ZoneId] = { + createPrimitiveFromThrowing(str => ZoneId.of(str), s"Cannot parse ZoneId.") + } + + private val zoneOffsetDecoder: XmlDecoder[ZoneOffset] = { + createPrimitiveFromThrowing(str => ZoneOffset.of(str), s"Cannot parse ZoneOffset.") + } + + private def decodeOption[A](d: XmlDecoder[A]): XmlDecoder[Option[A]] = + new XmlDecoder[Option[A]] { + final override def isOptional: Boolean = true + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], Option[A]] = { + if (node.isEmpty) Right(None) + else d.parse(node, path).map(Some(_)) + } + } + + private def decodeChunk[A](underlying: XmlDecoder[A], nodeName: String): XmlDecoder[Chunk[A]] = + new XmlDecoder[Chunk[A]] { + final override def isOptional: Boolean = false + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], Chunk[A]] = { + val children = node.iterator.filter(_.label == nodeName) + .map { + _.child + .map { + case Text(data) => Text(data.trim) + case other => other + } + .filter { + case Text(data) => data.nonEmpty + case _ => true + } + } + .filter(_.nonEmpty) + + val errorBuilder = ListBuffer[XmlDecoderError]() + val chunkBuilder = Chunk.newBuilder[A] + + var i = 0 + for (child <- children) { + // XPath starts with 1 + underlying.parse(child, s"[${i + 1}]" :: nodeName :: path) match { + case Right(value) => + val _ = chunkBuilder += value + case Left(errors) => + val _ = errorBuilder ++= errors + } + i += 1 + } + + if (errorBuilder.isEmpty) Right(chunkBuilder.result()) + else Left(errorBuilder.toList); + } + } + + final private[xml] class LazyDecoder[A] extends XmlDecoder[A] with ConverterCache.Lazy[XmlDecoder[A]] { + override def isOptional: Boolean = get.isOptional +// override def isPrimitive: Boolean = get.isOptional + + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = { + get.parse(node, path) + } + } + + private class Converter(cache: ConverterCache[XmlDecoder] = new ConverterCache[XmlDecoder]()) { + private def convertUsingCache[A]( + typeId: TypeId, + annotations: XmlAnnotations[A] + )(convert: => XmlDecoder[A]): XmlDecoder[A] = { + cache.convertUsingCache(typeId, annotations.openApiAnnotations)(convert)(() => new LazyDecoder[A]()) + } + + def convert[A](schema: Schema[A], fieldName: Option[String]): XmlDecoder[A] = { + schema match { + case Primitive(standardType, annotations) => + primitiveConverter(standardType, annotations) + + case Schema.Sequence(schemaA, fromChunk, _, annotations, _) => + val parsed = extractAnnotations[A](annotations) + val nodeName = parsed.xmlFieldName + .orElse(fieldName) + .getOrElse { + throw new RuntimeException("Sequence must have xmlFieldName annotation") + } + addAnnotations( + decodeChunk(convert(schemaA, Some(nodeName)), nodeName).map(fromChunk), + parsed + ) + + case Schema.Set(_, _) => + ??? + + case Schema.Transform(schema, f, _, annotations, _) => + val parsed = extractAnnotations[A](annotations) + val baseDecoder = convert(schema, parsed.xmlFieldName.orElse(fieldName)).emap(f) + addAnnotations(baseDecoder, parsed) + + case Schema.Optional(schema, annotations) => + val parsed = extractAnnotations[A](annotations) + addAnnotations[A]( + baseDecoder = decodeOption(convert(schema, parsed.xmlFieldName.orElse(fieldName))), + metadata = parsed + ) + + case l @ Schema.Lazy(_) => + convert(l.schema, fieldName) + + case record: Schema.Record[A] => + convertRecord(record) + + case s: Schema.Enum[A] => + convertEnum(s) + + case _ => + ??? + } + } + + private def convertEnum[A](schema: Schema.Enum[A]): XmlDecoder[A] = { + val enumAnnotations = extractAnnotations[A](schema.annotations) + val serDeStrategy = enumAnnotations.openApiAnnotations.sumTypeSerDeStrategy + .getOrElse { + throw new RuntimeException( + s"Discriminator must be defined to derive an XmlDecoder. Received annotations: $enumAnnotations" + ) + } + + serDeStrategy match { + case OpenApiSumTypeSerDeStrategy.Discriminator(discriminator) => + val reversedDiscriminator = discriminator.mapping.map(_.swap) + if (reversedDiscriminator.size != discriminator.mapping.size) { + throw new RuntimeException( + s"Cannot derive XmlDecoder for ${schema.id.name}, because discriminator mapping is not unique." + ) + } + + val decoderByDiscType = schema.cases.iterator + .map { c => + val decoder = addAnnotations( + convert(c.schema, None), + extractAnnotations(c.annotations) + ) + reversedDiscriminator(c.caseName) -> (decoder, c) + } + .toMap + + convertUsingCache(schema.id, enumAnnotations) { + val baseDecoder = new XmlDecoder[A] { + final override def isOptional: Boolean = false + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = { + val nodes = node.collect { case e: Elem => e } + if (nodes.size != 1) { + if (nodes.isEmpty) { + Left(List(XmlDecoderError(s"Expected single node, but none found.", path))) + } + else { + Left(List(XmlDecoderError( + s"Expected single node, but the following nodes instead: ${node.map(_.label).mkString(", ")}.", + path + ))) + } + } + else { + val e = nodes.head + decoderByDiscType.get(e.label) match { + case None => Left(List(XmlDecoderError( + s"Unrecognized object type: ${e.label}. Valid object types are: ${reversedDiscriminator.keys.mkString(", ")}.", + path + ))) + case Some((decoder, _)) => + decoder.parse(e.child, e.label :: path).asInstanceOf[Either[List[XmlDecoderError], A]] + } + } + } + } + addAnnotations(baseDecoder, enumAnnotations) + } + } + + } + + private def convertRecord[A](record: Schema.Record[A]): XmlDecoder[A] = { + val recordAnnotations: XmlAnnotations[A] = extractAnnotations[A](record.annotations) + convertUsingCache(record.id, recordAnnotations) { + val fieldDecoders = record.fields + .map { field => + val fieldAnnotations = extractAnnotations[Any](field.annotations) + addAnnotations[Any]( + convert[Any]( + field.schema.asInstanceOf[Schema[Any]], + fieldAnnotations.xmlFieldName.orElse(Some(field.name)) + ), + fieldAnnotations + ) + } + val fieldNames = record.fields.map { field => extractAnnotations[Any](field.annotations).xmlFieldName } + val isFieldSeq = record.fields.map { field => OpenApiConverterUtils.isSeq(field.schema) } + val baseDecoder = new XmlDecoder[A] { + final override def isOptional: Boolean = false + override def parse(node: NodeSeq, path: List[String]): Either[List[XmlDecoderError], A] = { + val errorBuilder = ListBuffer.empty[XmlDecoderError] + val argsBuilder = Chunk.newBuilder[Any] + var i = 0 + while (i < record.fields.length) { + val field = record.fields(i) + val fieldName = fieldNames(i).getOrElse(field.name) + val fieldDecoder = fieldDecoders(i) + val isSeq = isFieldSeq(i) + val childNodeIndex = node.indexWhere(_.label == fieldName) + if (childNodeIndex == -1) { + if (isSeq) { + fieldDecoder.parse(NodeSeq.Empty, path) match { + case Right(value) => + val _ = argsBuilder += value + case Left(errors) => + val _ = errorBuilder ++= errors + } + } + else if (fieldDecoder.isOptional) { + fieldDecoder.parse(NodeSeq.Empty, path) match { + case Right(value) => + val _ = argsBuilder += value + case Left(errors) => + val _ = errorBuilder ++= errors + } + } + else { + val _ = errorBuilder += XmlDecoderError(s"Expected field '$fieldName', but it was not found.", path) + } + } + else { + if (isSeq) { + fieldDecoder.parse(node, path) match { + case Right(value) => + val _ = argsBuilder += value + case Left(errors) => + val _ = errorBuilder ++= errors + } + } + else { + val fieldNodes = node(childNodeIndex).child + val result = fieldDecoder.parse(fieldNodes, fieldName :: path) + result match { + case Right(value) => + val _ = argsBuilder += value + case Left(errors) => + val _ = errorBuilder ++= errors + } + } + } + i += 1 + } + + if (errorBuilder.nonEmpty) Left(errorBuilder.toList) + else { + val args = argsBuilder.result() + if (args.length != record.fields.length) { + Left(List(XmlDecoderError(s"Expected ${record.fields.length} fields, but got ${args.length}.", path))) + } + else { + val result = zio.Unsafe.unsafe { implicit unsafe => + record.construct(args) match { + case Right(value) => value + case Left(error) => + throw new RuntimeException(s"Error constructing record [${record.id.name}]: $error") + } + } + Right(result) + } + } + } + } + addAnnotations(baseDecoder, recordAnnotations) + } + } + + private def primitiveConverter[A]( + standardType: StandardType[A], + annotations: Chunk[Any] + ): XmlDecoder[A] = { + val baseDecoder = standardType match { + case StandardType.UnitType => unitDecoder + case StandardType.StringType => stringDecoder + case StandardType.BoolType => boolDecoder + case StandardType.ByteType => byteDecoder + case StandardType.ShortType => shortDecoder + case StandardType.IntType => intDecoder + case StandardType.LongType => longDecoder + case StandardType.FloatType => floatDecoder + case StandardType.DoubleType => doubleDecoder + case StandardType.BinaryType => binaryViaHexDecoder + case StandardType.CharType => charDecoder + case StandardType.UUIDType => uuidDecoder + case StandardType.BigDecimalType => bigDecimalDecoder + case StandardType.BigIntegerType => bigIntDecoder + case StandardType.DayOfWeekType => dayOfWeekDecoder + case StandardType.MonthType => notSupported("MonthType") + case StandardType.MonthDayType => notSupported("MonthDayType") + case StandardType.PeriodType => notSupported("PeriodType") + case StandardType.YearType => notSupported("YearType") + case StandardType.YearMonthType => notSupported("YearMonthType") + case StandardType.ZoneIdType => zoneIdDecoder + case StandardType.ZoneOffsetType => zoneOffsetDecoder + case StandardType.DurationType => durationDecoder + case StandardType.InstantType => instantDecoder + case StandardType.LocalDateType => localDateDecoder + case StandardType.LocalTimeType => localTimeDecoder + case StandardType.LocalDateTimeType => localDateTimeDecoder + case StandardType.OffsetTimeType => offsetTimeDecoder + case StandardType.OffsetDateTimeType => offsetDateTimeDecoder + case StandardType.ZonedDateTimeType => zonedDateTimeDecoder + } + addAnnotations(baseDecoder.asInstanceOf[XmlDecoder[A]], extractAnnotations(annotations)) + } + + private def addAnnotations[A]( + baseDecoder: XmlDecoder[A], + metadata: XmlAnnotations[A] + ): XmlDecoder[A] = { + var decoder = baseDecoder + decoder = metadata.openApiAnnotations.default.fold(decoder) { case (default, _) => + decodeOption[A](decoder).map(maybeValue => maybeValue.getOrElse(default)) + } + decoder = metadata.openApiAnnotations.validator.fold(decoder) { validator: Validator[A] => + decoder.ensure { value => + validator(value).map(OpenApiValidation.errorMessage) + } + } + decoder + } + + private def notSupported(value: String) = { + throw new IllegalArgumentException( + s"Cannot convert ZIO-Schema to XmlDecoder, because $value is currently not supported." + ) + } + + } + +} diff --git a/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlEncoder.scala b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlEncoder.scala new file mode 100644 index 00000000..c19c7260 --- /dev/null +++ b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlEncoder.scala @@ -0,0 +1,395 @@ +package dev.chopsticks.xml + +import dev.chopsticks.openapi.common.{ConverterCache, OpenApiConverterUtils} +import dev.chopsticks.openapi.OpenApiSumTypeSerDeStrategy +import dev.chopsticks.util.Hex +import dev.chopsticks.xml.XmlAnnotations.extractAnnotations + +import java.time.{ + DayOfWeek, + Duration, + Instant, + LocalDate, + LocalDateTime, + LocalTime, + OffsetDateTime, + OffsetTime, + ZoneId, + ZoneOffset, + ZonedDateTime +} +import java.util.UUID +import scala.xml.{Elem, NodeSeq, TopScope} +import zio.Chunk +import zio.schema.{Schema, StandardType, TypeId} +import zio.schema.Schema.{Field, Primitive} + +import scala.annotation.nowarn + +trait XmlEncoder[A] { self => + def encode(a: A): NodeSeq + def isOptional: Boolean + final def contramap[B](f: B => A): XmlEncoder[B] = new XmlEncoder[B] { + override def isOptional: Boolean = self.isOptional + override def encode(value: B): NodeSeq = { + self.encode(f(value)) + } + } +} +object XmlEncoder { + def derive[A]()(implicit schema: Schema[A]): XmlEncoder[A] = { + new Converter().convert(schema, None) + } + + private def createPrimitiveEncoder[A](f: A => scala.xml.NodeSeq): XmlEncoder[A] = new XmlEncoder[A] { + final override def isOptional: Boolean = false + override def encode(a: A): NodeSeq = f(a) + } + + private val emptyNode = scala.xml.Text("") + + private val unitEncoder: XmlEncoder[Unit] = createPrimitiveEncoder(_ => emptyNode) + private val boolEncoder: XmlEncoder[Boolean] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val byteEncoder: XmlEncoder[Byte] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val shortEncoder: XmlEncoder[Short] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val intEncoder: XmlEncoder[Int] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val longEncoder: XmlEncoder[Long] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val floatEncoder: XmlEncoder[Float] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val doubleEncoder: XmlEncoder[Double] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val stringEncoder: XmlEncoder[String] = createPrimitiveEncoder(x => scala.xml.Text(x)) + private val charEncoder: XmlEncoder[Char] = createPrimitiveEncoder(x => scala.xml.Text(String.valueOf(x))) + private val binaryViaHexEncoder: XmlEncoder[Chunk[Byte]] = + createPrimitiveEncoder(x => scala.xml.Text(Hex.encode(x.toArray))) + private val instantEncoder: XmlEncoder[Instant] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val offsetDateTimeEncoder: XmlEncoder[OffsetDateTime] = + createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val zonedDateTimeEncoder: XmlEncoder[ZonedDateTime] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val localDateTimeEncoder: XmlEncoder[LocalDateTime] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val localDateEncoder: XmlEncoder[LocalDate] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val localTimeEncoder: XmlEncoder[LocalTime] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val offsetTimeEncoder: XmlEncoder[OffsetTime] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val zoneIdEncoder: XmlEncoder[ZoneId] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val zoneOffsetEncoder: XmlEncoder[ZoneOffset] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val bigDecimalEncoder: XmlEncoder[BigDecimal] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val bigIntEncoder: XmlEncoder[BigInt] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val uuidEncoder: XmlEncoder[UUID] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val dayOfWeekEncoder: XmlEncoder[DayOfWeek] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + private val durationEncoder: XmlEncoder[Duration] = createPrimitiveEncoder(x => scala.xml.Text(x.toString)) + + private def encodeOption[A](underlying: XmlEncoder[A]): XmlEncoder[Option[A]] = new XmlEncoder[Option[A]] { + final override def isOptional: Boolean = true + override def encode(a: Option[A]): NodeSeq = a match { + case Some(x) => underlying.encode(x) + case None => scala.xml.NodeSeq.Empty + } + } + + private def encodeChunk[A](underlying: XmlEncoder[A], nodeName: String): XmlEncoder[Chunk[A]] = + new XmlEncoder[Chunk[A]] { + final override def isOptional: Boolean = false + override def encode(xs: Chunk[A]): NodeSeq = { + val builder = scala.xml.NodeSeq.newBuilder + var i = 0 + while (i < xs.length) { + val childNodes = underlying.encode(xs(i)) + val _ = + builder += scala.xml.Elem(null, nodeName, scala.xml.Null, TopScope, minimizeEmpty = true, childNodes: _*) + i += 1 + } + builder.result() + } + } + + final private[xml] class LazyEncoder[A] extends XmlEncoder[A] with ConverterCache.Lazy[XmlEncoder[A]] { + override def isOptional: Boolean = get.isOptional + override def encode(a: A): NodeSeq = { + get.encode(a) + } + } + + private class Converter( + cache: ConverterCache[XmlEncoder] = new ConverterCache[XmlEncoder]() + ) { + private def convertUsingCache[A]( + typeId: TypeId, + annotations: XmlAnnotations[A] + )(convert: => XmlEncoder[A]): XmlEncoder[A] = { + cache.convertUsingCache(typeId, annotations.openApiAnnotations)(convert)(() => new LazyEncoder[A]()) + } + + // scalafmt: { maxColumn = 800, optIn.configStyleArguments = false } + def convert[A](schema: Schema[A], fieldName: Option[String]): XmlEncoder[A] = { + schema match { + case Primitive(standardType, annotations) => + primitiveConverter(standardType, annotations) + + case Schema.Sequence(schemaA, _, toChunk, annotations, _) => + val parsed = extractAnnotations[A](annotations) + val nodeName = parsed.xmlFieldName + .orElse(fieldName) + .getOrElse { + throw new RuntimeException("Sequence must have xmlFieldName annotation") + } + addAnnotations( + None, + encodeChunk(convert(schemaA, Some(nodeName)), nodeName).contramap(toChunk), + parsed + ) + + case Schema.Set(_, _) => + ??? + + case Schema.Transform(schema, _, g, annotations, _) => + val typedAnnotations = extractAnnotations[A](annotations) + val baseEncoder = convert(schema, typedAnnotations.xmlFieldName.orElse(fieldName)).contramap[A] { x => + g(x) match { + case Right(v) => v + case Left(error) => throw new RuntimeException(s"Couldn't transform schema: $error") + } + } + addAnnotations(None, baseEncoder, typedAnnotations) + + case Schema.Optional(schema, annotations) => + val parsed = extractAnnotations[A](annotations) + addAnnotations[A]( + None, + baseEncoder = encodeOption(convert(schema, parsed.xmlFieldName.orElse(fieldName))).asInstanceOf[XmlEncoder[A]], + metadata = extractAnnotations(annotations) + ) + + case l @ Schema.Lazy(_) => + convert(l.schema, fieldName) + + case s: Schema.Record[A] => + convertRecord[A](s.id, s.annotations, s.fields) + + case s: Schema.Enum[A] => + convertEnum[A](s) + + case _ => + ??? + + } + } + + private def convertEnum[A](schema: Schema.Enum[A]): XmlEncoder[A] = { + val enumAnnotations = extractAnnotations[A](schema.annotations) + val serDeStrategy = enumAnnotations.openApiAnnotations.sumTypeSerDeStrategy + .getOrElse { + throw new RuntimeException( + s"Discriminator must be defined to derive an XmlEncoder. Received annotations: $enumAnnotations" + ) + } + + serDeStrategy match { + case OpenApiSumTypeSerDeStrategy.Discriminator(discriminator) => + if (discriminator.mapping.size != schema.cases.size) { + throw new RuntimeException( + s"Cannot derive XmlEncoder for ${schema.id.name}, because discriminator mapping has different length than the number of cases. Discriminator mapping length = ${discriminator.mapping.size}, possible cases: ${schema.cases.map(_.caseName).mkString(", ")}." + ) + } + val reversedDiscriminator = discriminator.mapping.map(_.swap) + if (reversedDiscriminator.size != discriminator.mapping.size) { + throw new RuntimeException( + s"Cannot derive XmlEncoder for ${schema.id.name}, because discriminator mapping is not unique." + ) + } + val encoderByDiscValue = { + schema.cases.iterator + .map { c => + val encoder = addAnnotations( + None, + convert(c.schema, None), + extractAnnotations(c.annotations) + ).asInstanceOf[XmlEncoder[Any]] + reversedDiscriminator(c.caseName) -> encoder + } + .toMap + } + + convertUsingCache(schema.id, enumAnnotations) { + val baseDecoder = new XmlEncoder[A] { + final override def isOptional: Boolean = false + override def encode(a: A): NodeSeq = { + val discriminatorValue = discriminator.discriminatorValue(a) + val encoder = encoderByDiscValue(discriminatorValue) + Elem( + null, + discriminatorValue, + scala.xml.Null, + TopScope, + minimizeEmpty = true, + encoder.encode(a): _* + ) + } + } + addAnnotations(Some(schema.id), baseDecoder, enumAnnotations) + } + + } + + } + + private def convertRecord[A](id: TypeId, annotations: Chunk[Any], fields: Chunk[Field[A, _]]): XmlEncoder[A] = { + val recordAnnotations: XmlAnnotations[A] = extractAnnotations[A](annotations) + convertUsingCache(id, recordAnnotations) { + val fieldEncoders = fields + .map { field => + val fieldAnnotations = extractAnnotations[Any](field.annotations) + addAnnotations[Any]( + None, + convert[Any]( + field.schema.asInstanceOf[Schema[Any]], + fieldAnnotations.xmlFieldName.orElse(Some(field.name)) + ), + fieldAnnotations + ) + } + val fieldNames = fields.map { field => extractAnnotations[Any](field.annotations).xmlFieldName } + val isFieldSeq = fields.map { field => OpenApiConverterUtils.isSeq(field.schema) } + val baseEncoder = new XmlEncoder[A] { + final override def isOptional: Boolean = false + override def encode(value: A): NodeSeq = { + val builder = NodeSeq.newBuilder + var i = 0 + while (i < fields.length) { + val field = fields(i) + val encoder = fieldEncoders(i) + val isSeq = isFieldSeq(i) + val childNodes = encoder + .asInstanceOf[XmlEncoder[Any]] + .encode(field.get(value)) + if (!(childNodes.isEmpty && encoder.isOptional)) { + if (!isSeq) { + val fieldName = fieldNames(i).getOrElse(field.name) + val newElem = scala.xml.Elem( + null, + fieldName, + scala.xml.Null, + TopScope, + minimizeEmpty = true, + childNodes: _* + ) + val _ = builder.addOne(newElem) + } + else { + val _ = builder.addAll(childNodes) + } + } + i += 1 + } + builder.result() + } + } + addAnnotations(Some(id), baseEncoder, recordAnnotations) + } + } + +// private def convertEnum[A]( +// id: TypeId, +// annotations: Chunk[Any], +// cases: Chunk[Schema.Case[A, _]] +// ): XmlEncoder[A] = { +// val enumAnnotations = extractAnnotations[A](annotations) +// val encodersByName = cases.iterator +// .map { c => +// val cAnn = extractAnnotations(c.annotations) +// val encoder = addAnnotations( +// None, +// convert(c.schema), +// extractAnnotations(c.annotations) +// ).asInstanceOf[XmlEncoder[Any]] +// val entityName = OpenApiConverterUtils.getCaseEntityName(c, cAnn).getOrElse(throw new RuntimeException( +// s"Subtype of ${enumAnnotations.entityName.getOrElse("-")} must have entityName defined or be a case class to derive a XmlEncoder. Received annotations: $cAnn" +// )) +// entityName -> (encoder, c) +// } +// .toMap +// val discriminator = enumAnnotations.sumTypeSerDeStrategy +// +// val decoder = discriminator +// .getOrElse(throw new RuntimeException( +// s"Discriminator must be defined to derive an XmlEncoder. Received annotations: $enumAnnotations" +// )) match { +// case OpenApiSumTypeSerDeStrategy.Discriminator(discriminator) => +// val diff = discriminator.mapping.values.toSet.diff(encodersByName.keySet) +// if (diff.nonEmpty) { +// throw new RuntimeException( +// s"Cannot derive CsvEncoder for ${enumAnnotations.entityName.getOrElse("-")}, because mapping and decoders don't match. Diff=$diff." +// ) +// } +// new XmlEncoder[A] { +// override def encode( +// value: A, +// columnName: Option[String], +// acc: mutable.LinkedHashMap[String, String] +// ): mutable.LinkedHashMap[String, String] = { +// var res = acc +// val discValue = discriminator.discriminatorValue(value) +// val (enc, c) = encodersByName(discriminator.mapping(discValue)) +// val discriminatorColumnName = +// Some(options.nestedFieldLabel(columnName, discriminator.discriminatorFieldName)) +// res = stringEncoder.encode(discriminator.discriminatorFieldName, discriminatorColumnName, res) +// enc.encode(c.deconstruct(value).asInstanceOf[Any], columnName, res) +// } +// } +// } +// addAnnotations(Some(id), decoder, enumAnnotations) +// } + + private def primitiveConverter[A]( + standardType: StandardType[A], + annotations: Chunk[Any] + ): XmlEncoder[A] = { + val baseEncoder = standardType match { + case StandardType.UnitType => unitEncoder + case StandardType.StringType => stringEncoder + case StandardType.BoolType => boolEncoder + case StandardType.ByteType => byteEncoder + case StandardType.ShortType => shortEncoder + case StandardType.IntType => intEncoder + case StandardType.LongType => longEncoder + case StandardType.FloatType => floatEncoder + case StandardType.DoubleType => doubleEncoder + case StandardType.BinaryType => binaryViaHexEncoder + case StandardType.CharType => charEncoder + case StandardType.UUIDType => uuidEncoder + case StandardType.BigDecimalType => bigDecimalEncoder + case StandardType.BigIntegerType => bigIntEncoder + case StandardType.DayOfWeekType => dayOfWeekEncoder + case StandardType.MonthType => notSupported("MonthType") + case StandardType.MonthDayType => notSupported("MonthDayType") + case StandardType.PeriodType => notSupported("PeriodType") + case StandardType.YearType => notSupported("YearType") + case StandardType.YearMonthType => notSupported("YearMonthType") + case StandardType.ZoneIdType => zoneIdEncoder + case StandardType.ZoneOffsetType => zoneOffsetEncoder + case StandardType.DurationType => durationEncoder + case StandardType.InstantType => instantEncoder + case StandardType.LocalDateType => localDateEncoder + case StandardType.LocalTimeType => localTimeEncoder + case StandardType.LocalDateTimeType => localDateTimeEncoder + case StandardType.OffsetTimeType => offsetTimeEncoder + case StandardType.OffsetDateTimeType => offsetDateTimeEncoder + case StandardType.ZonedDateTimeType => zonedDateTimeEncoder + } + addAnnotations(None, baseEncoder.asInstanceOf[XmlEncoder[A]], extractAnnotations(annotations)) + } + + @nowarn("cat=unused-params") + private def addAnnotations[A]( + typeId: Option[TypeId], + baseEncoder: XmlEncoder[A], + metadata: XmlAnnotations[A] + ): XmlEncoder[A] = { + baseEncoder + } + + private def notSupported(value: String) = { + throw new IllegalArgumentException( + s"Cannot convert ZIO-Schema to XmlEncoder, because $value is currently not supported." + ) + } + + } +} diff --git a/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlModel.scala b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlModel.scala new file mode 100644 index 00000000..db2742a1 --- /dev/null +++ b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlModel.scala @@ -0,0 +1,9 @@ +package dev.chopsticks.xml + +import zio.schema.Schema + +trait XmlModel[A] { + implicit def zioSchema: Schema[A] + implicit lazy val xmlEncoder: XmlEncoder[A] = XmlEncoder.derive[A]()(zioSchema) + implicit lazy val xmlDecoder: XmlDecoder[A] = XmlDecoder.derive[A]()(zioSchema) +} diff --git a/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlPrettyPrinter.scala b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlPrettyPrinter.scala new file mode 100644 index 00000000..61a82718 --- /dev/null +++ b/chopsticks-xml/src/main/scala/dev/chopsticks/xml/XmlPrettyPrinter.scala @@ -0,0 +1,293 @@ +package dev.chopsticks.xml + +import scala.collection.Seq +import scala.xml.{ + Atom, + Comment, + EntityRef, + Group, + MinimizeMode, + NamespaceBinding, + Node, + ProcInstr, + Text, + TextBuffer, + TopScope, + Utility, + XML +} + +// copied from scala-xml +class XmlPrettyPrinter(width: Int, step: Int, minimizeEmpty: Boolean) { + + def this(width: Int, step: Int) = this(width, step, minimizeEmpty = false) + + val minimizeMode: MinimizeMode.Value = if (minimizeEmpty) MinimizeMode.Always else MinimizeMode.Default + class BrokenException() extends java.lang.Exception + + class Item + case object Break extends Item { + override def toString: String = "\\" + } + case class Box(col: Int, s: String) extends Item + case class Para(s: String) extends Item + + protected var items: List[Item] = Nil + + protected var cur: Int = 0 + + protected def reset(): Unit = { + cur = 0 + items = Nil + } + + /** Try to cut at whitespace. + */ + protected def cut(s: String, ind: Int): List[Item] = { + val tmp: Int = width - cur + if (s.length <= tmp) + return List(Box(ind, s)) + var i: Int = s.indexOf(' ') + if (i > tmp || i == -1) throw new BrokenException() // cannot break + + var last: List[Int] = Nil + while (i != -1 && i < tmp) { + last = i :: last + i = s.indexOf(' ', i + 1) + } + var res: List[Item] = Nil + while (Nil != last) + try { + val b: Box = Box(ind, s.substring(0, last.head)) + cur = ind + res = b :: Break :: cut(s.substring(last.head, s.length), ind) + // backtrack + last = last.tail + } + catch { + case _: BrokenException => last = last.tail + } + throw new BrokenException() + } + + /** Try to make indented box, if possible, else para. + */ + protected def makeBox(ind: Int, s: String): Unit = + if (cur + s.length > width) { // fits in this line + items ::= Box(ind, s) + cur += s.length + } + else + try cut(s, ind).foreach(items ::= _) // break it up + catch { case _: BrokenException => makePara(ind, s) } // give up, para + + // dont respect indent in para, but afterwards + protected def makePara(ind: Int, s: String): Unit = { + items = Break :: Para(s) :: Break :: items + cur = ind + } + + // respect indent + protected def makeBreak(): Unit = { // using wrapping here... + items = Break :: items + cur = 0 + } + + protected def leafTag(n: Node): String = { + def mkLeaf(sb: StringBuilder): Unit = { + val _ = sb.append('<') + val _ = n.nameToString(sb) + val _ = n.attributes.buildString(sb) + val _ = sb.append("/>") + } + sbToString(mkLeaf) + } + + protected def startTag(n: Node, pscope: NamespaceBinding): (String, Int) = { + var i: Int = 0 + def mkStart(sb: StringBuilder): Unit = { + val _ = sb.append('<') + val _ = n.nameToString(sb) + i = sb.length + 1 + val _ = n.attributes.buildString(sb) + val _ = n.scope.buildString(sb, pscope) + val _ = sb.append('>') + } + (sbToString(mkStart), i) + } + + protected def endTag(n: Node): String = { + def mkEnd(sb: StringBuilder): Unit = { + val _ = sb.append("') + } + sbToString(mkEnd) + } + + protected def childrenAreLeaves(n: Node): Boolean = { + def isLeaf(l: Node): Boolean = l match { + case _: Atom[?] | _: Comment | _: EntityRef | _: ProcInstr => true + case _ => false + } + n.child.forall(isLeaf) + } + + protected def fits(test: String): Boolean = + test.length < width - cur + + private def doPreserve(node: Node): Boolean = + node.attribute(XML.namespace, XML.space).exists(_.toString == XML.preserve) + + protected def traverse(node: Node, pscope: NamespaceBinding, ind: Int): Unit = node match { + case Text(s) if s.trim.isEmpty => + + case _: Atom[?] | _: Comment | _: EntityRef | _: ProcInstr => + makeBox(ind, node.toString.trim) + case Group(xs) => + traverse(xs.iterator, pscope, ind) + case _ => + val test: String = { + val sb: StringBuilder = new StringBuilder() + Utility.serialize(node, pscope, sb, stripComments = false, minimizeTags = minimizeMode) + if (doPreserve(node)) sb.toString + else TextBuffer.fromString(sb.toString).toText(0).data + } + if (childrenAreLeaves(node) && fits(test)) + makeBox(ind, test) + else { + val ((stg: String, len2: Int), etg: String) = + if (node.child.isEmpty && minimizeEmpty) { + // force the tag to be self-closing + val firstAttribute: Int = test.indexOf(' ') + val firstBreak: Int = if (firstAttribute != -1) firstAttribute else test.lastIndexOf('/') + ((test, firstBreak), "") + } + else + (startTag(node, pscope), endTag(node)) + + if (stg.length < width - cur) { // start tag fits + makeBox(ind, stg) + makeBreak() + traverse(node.child.iterator, node.scope, ind + step) + makeBox(ind, etg) + } + else if (len2 < width - cur) { + // + if (!lastwasbreak) { + val _ = sb.append('\n') + } // on windows: \r\n ? + lastwasbreak = true + cur = 0 + // while (cur < last) { + // sb.append(' ') + // cur += 1 + // } + + case Box(i, s) => + lastwasbreak = false + while (cur < i) { + val _ = sb.append(' ') + cur += 1 + } + sb.append(s) + case Para(s) => + lastwasbreak = false + sb.append(s) + case other => + throw new IllegalArgumentException("unknown item: " + other) + } + } + + // public convenience methods + + /** Returns a formatted string containing well-formed XML with given namespace to prefix mapping. + * + * @param n + * the node to be serialized + * @param pscope + * the namespace to prefix mapping + * @return + * the formatted string + */ + def format(n: Node, pscope: NamespaceBinding = TopScope): String = + sbToString(format(n, pscope, _)) + + /** Returns a formatted string containing well-formed XML. + * + * @param nodes + * the sequence of nodes to be serialized + * @param pscope + * the namespace to prefix mapping + */ + def formatNodes(nodes: Seq[Node], pscope: NamespaceBinding = TopScope): String = + sbToString(formatNodes(nodes, pscope, _)) + + /** Appends a formatted string containing well-formed XML with the given namespace to prefix mapping to the given + * stringbuffer. + * + * @param nodes + * the nodes to be serialized + * @param pscope + * the namespace to prefix mapping + * @param sb + * the string buffer to which to append to + */ + def formatNodes(nodes: Seq[Node], pscope: NamespaceBinding, sb: StringBuilder): Unit = + nodes.foreach(n => sb.append(format(n, pscope))) + + private def sbToString(f: StringBuilder => Unit): String = { + val sb: StringBuilder = new StringBuilder + f(sb) + sb.toString + } +} diff --git a/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlCodecTest.scala b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlCodecTest.scala new file mode 100644 index 00000000..ace3eff8 --- /dev/null +++ b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlCodecTest.scala @@ -0,0 +1,20 @@ +package dev.chopsticks.xml + +import dev.chopsticks.xml.XmlAnnotations.xmlFieldName +import zio.schema.{DeriveSchema, Schema} + +final case class XmlTestPerson( + name: String, + age: Option[Int], + nickname: Option[String], + @xmlFieldName("address") + addresses: List[XmlTestAddress] +) +object XmlTestPerson extends XmlModel[XmlTestPerson] { + implicit override lazy val zioSchema: Schema[XmlTestPerson] = DeriveSchema.gen +} + +final case class XmlTestAddress(city: String, street: String, zip: Option[String]) +object XmlTestAddress extends XmlModel[XmlTestAddress] { + implicit override lazy val zioSchema: Schema[XmlTestAddress] = DeriveSchema.gen +} diff --git a/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlDecoderTest.scala b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlDecoderTest.scala new file mode 100644 index 00000000..22a87229 --- /dev/null +++ b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlDecoderTest.scala @@ -0,0 +1,84 @@ +package dev.chopsticks.xml + +import org.scalatest.matchers.should.Matchers +import org.scalatest.Assertions +import org.scalatest.wordspec.AnyWordSpecLike + +final class XmlDecoderTest extends AnyWordSpecLike with Assertions with Matchers { + "XmlDecoder" should { + "decode a simple case class" in { + val xml = + NY + 1st street + val decoded = XmlTestAddress.xmlDecoder.parse(xml) + val expected = XmlTestAddress("NY", "1st street", None) + assert(decoded == Right(expected)) + } + + "decode a nested case class" in { + val xml = { + John + 30 +
+ NY + 1st street +
+
+ LA + 2nd street + 12345 +
+ } + val expected = XmlTestPerson( + name = "John", + age = Some(30), + nickname = None, + addresses = List( + XmlTestAddress("NY", "1st street", None), + XmlTestAddress("LA", "2nd street", Some("12345")) + ) + ) + val decoded = XmlTestPerson.xmlDecoder.parse(xml) + assert(decoded == Right(expected)) + } + + "decode a nested case class with a single element list" in { + val xml = { + John + 30 +
+ NY + 1st street +
+ } + val expected = XmlTestPerson( + name = "John", + age = Some(30), + nickname = None, + addresses = List( + XmlTestAddress("NY", "1st street", None) + ) + ) + val decoded = XmlTestPerson.xmlDecoder.parse(xml) + assert(decoded == Right(expected)) + } + + "decode a nested case class with an empty list" in { + val xml = { + John + 30 +
+
+ } + val expected = XmlTestPerson( + name = "John", + age = Some(30), + nickname = None, + addresses = List.empty + ) + val decoded = XmlTestPerson.xmlDecoder.parse(xml) + assert(decoded == Right(expected)) + } + + } +} diff --git a/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlEncoderTest.scala b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlEncoderTest.scala new file mode 100644 index 00000000..5b32bf0f --- /dev/null +++ b/chopsticks-xml/src/test/scala/dev/chopsticks/xml/XmlEncoderTest.scala @@ -0,0 +1,73 @@ +package dev.chopsticks.xml + +import org.scalatest.wordspec.AnyWordSpecLike +import org.scalatest.Assertions +import org.scalatest.matchers.should.Matchers + +import scala.xml.{Elem, Node, NodeBuffer, Null, Text, TopScope, Utility} + +final class XmlEncoderTest extends AnyWordSpecLike with Assertions with Matchers { + + "XmlEncoder" should { + "encode a simple case class" in { + val address = XmlTestAddress("NY", "1st street", None) + val encoded = renderAsXml(address) + val expected = + NY + 1st street + + assert(encoded == renderAsXml(expected)) + } + + "encode a nested case class" in { + val person = XmlTestPerson( + name = "John", + age = Some(30), + nickname = None, + addresses = List( + XmlTestAddress("NY", "1st street", None), + XmlTestAddress("LA", "2nd street", Some("12345")) + ) + ) + val encoded = renderAsXml(person) + val expected = { + John + 30 +
+ NY + 1st street +
+
+ LA + 2nd street + 12345 +
+ } + val expectedEncoded = renderAsXml(expected) + assert(expectedEncoded == encoded) + } + } + + private def renderAsXml[A: XmlEncoder](value: A): String = { + val nodes = implicitly[XmlEncoder[A]].encode(value) + val wrapped = Elem(null, "root", Null, TopScope, minimizeEmpty = true, nodes: _*) + val cleanedXml = trimTextNodes(wrapped) + Utility.serialize(cleanedXml).result() + } + + private def renderAsXml(value: NodeBuffer): String = { + val wrapped = Elem(null, "root", Null, TopScope, minimizeEmpty = true, value: _*) + val cleanedXml = trimTextNodes(wrapped) + Utility.serialize(cleanedXml).result() + } + + // Helper function to recursively trim text nodes + private def trimTextNodes(node: Node): Node = node match { + case Elem(prefix, label, attribs, scope, child @ _*) => + Elem(prefix, label, attribs, scope, minimizeEmpty = true, child.map(trimTextNodes): _*) + case Text(data) => + Text(data.trim) + case other => other + } + +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 28e3a774..99048655 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -229,6 +229,8 @@ object Dependencies { val commonsText = Seq("org.apache.commons" % "commons-text" % "1.11.0") + val scalaXml = Seq("org.scala-lang.modules" %% "scala-xml" % "1.3.0") + lazy val tapirDeps = { val tapirVersion = "1.9.10" Seq(