Skip to content

Commit

Permalink
DynamicSchemaIndex improvements (#904)
Browse files Browse the repository at this point in the history
  • Loading branch information
Baccata authored Apr 17, 2023
1 parent 2c046c8 commit 20f42e7
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import smithy4s.Document._
import smithy4s.http.PayloadError
import smithy4s.schema._
import smithy4s.schema.Primitive._
import scala.collection.immutable.ListMap

trait DocumentDecoder[A] { self =>
def apply(history: List[PayloadPath.Segment], document: Document): A
Expand Down Expand Up @@ -226,7 +227,7 @@ class DocumentDecoderSchemaVisitor(
maybeKeyDecoder match {
case Some(keyDecoder) =>
DocumentDecoder.instance("Map", "Object") { case (pp, DObject(map)) =>
val builder = Map.newBuilder[K, V]
val builder = ListMap.newBuilder[K, V]
map.foreach { case (key, value) =>
val decodedKey = keyDecoder(DString(key)).fold(
{ case DocumentKeyDecoder.DecodeError(expectedType) =>
Expand Down
21 changes: 13 additions & 8 deletions modules/dynamic/src-jvm/NodeToDocument.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package smithy4s.dynamic

import smithy4s.Document
import software.amazon.smithy.model.node._
import scala.collection.immutable.ListMap
import scala.jdk.CollectionConverters._
import java.util.function.BiConsumer

object NodeToDocument {

def apply(node: Node): Document =
return node.accept(new NodeVisitor[Document] {
return node.accept(new NodeVisitor[Document] { self =>
def arrayNode(x: ArrayNode): Document =
Document.array(x.getElements().asScala.map(_.accept(this)))

Expand All @@ -37,14 +39,17 @@ object NodeToDocument {
Document.fromDouble(x.getValue().doubleValue())

def objectNode(x: ObjectNode): Document =
Document.obj(
Document.DObject {
val builder = ListMap.newBuilder[String, Document]
x.getMembers()
.asScala
.map { case (key, value) =>
key.getValue() -> value.accept(this)
}
.toSeq: _*
)
.forEach(new BiConsumer[StringNode, Node] {
def accept(key: StringNode, value: Node): Unit = {
val kv = (key.getValue(), value.accept(self))
builder += kv
}
})
builder.result()
}

def stringNode(x: StringNode): Document =
Document.fromString(x.getValue())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ trait DynamicSchemaIndex {
def getService(shapeId: ShapeId): Option[DynamicSchemaIndex.ServiceWrapper] =
allServices.find(_.service.id == shapeId)

def allSchemas: Vector[Schema[_]]
def getSchema(shapeId: ShapeId): Option[Schema[_]]

def metadata: Map[String, Document]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ private[dynamic] object Compiler {
id,
shape.traits, {
val lFields = {
shape.members.zipWithIndex
.map { case ((label, mShape), index) =>
shape.members.toVector.zipWithIndex.map {
case ((label, mShape), index) =>
val lMemberSchema = schema(mShape.target)
val lField =
if (
Expand All @@ -495,9 +495,7 @@ private[dynamic] object Compiler {
}
val memberHints = allHints(mShape.traits)
lField.map(_.addHints(memberHints.all.toSeq: _*))
}
.toVector
.sequence
}.sequence
}
if (isRecursive(id)) {
Eval.later(recursive(struct(lFields.value)(Constructor)))
Expand All @@ -510,15 +508,13 @@ private[dynamic] object Compiler {
id,
shape.traits, {
val lAlts =
shape.members.zipWithIndex
.map { case ((label, mShape), index) =>
shape.members.toVector.zipWithIndex.map {
case ((label, mShape), index) =>
val memberHints = allHints(mShape.traits)
schema(mShape.target)
.map(_.oneOf[DynAlt](label, Injector(index)))
.map(_.addHints(memberHints))
}
.toVector
.sequence
}.sequence
if (isRecursive(id)) {
Eval.later(recursive {
val alts = lAlts.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ private[internals] class DynamicSchemaIndexImpl(

def allServices: List[DynamicSchemaIndex.ServiceWrapper] =
serviceMap.values.toList
def allSchemas: Vector[Schema[_]] =
schemaMap.values.toVector

def getSchema(
shapeId: ShapeId
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2021-2022 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithy4s.dynamic

import org.scalacheck.Gen
import org.scalacheck.Prop.forAll
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.Model
import smithy4s.schema.Schema
import org.scalacheck.Shrink

class FieldOrderingSpec extends munit.ScalaCheckSuite {

property("Ordering of fields is retained") {
val genStrings = Gen.listOfN(10, Gen.identifier).map(_.distinct)
implicit def noShrink[A]: Shrink[A] = Shrink.shrinkAny
forAll(genStrings) { (names: List[String]) =>
val structBuilder = StructureShape.builder().id("foo#Foo")
val unionBuilder = UnionShape.builder().id("foo#Bar")
names.foreach { name =>
structBuilder.addMember(name, ShapeId.from("smithy.api#Integer"))
unionBuilder.addMember(name, ShapeId.from("smithy.api#Integer"))
}
val struct = structBuilder.build()
val union = unionBuilder.build()
val model = Model.builder().addShapes(struct, union).build()
val schemaIndex = DynamicSchemaIndex
.loadModel(model)
.toOption
.get

for {
id <- List("Foo", "Bar")
} {
val schema = schemaIndex
.getSchema(smithy4s.ShapeId("foo", id))
.getOrElse(fail(s"Error: $id shape missing"))

schema match {
case Schema.StructSchema(_, _, fields, _) =>
val fieldNames = fields.map(_.label).toList
assertEquals(fieldNames, names)
case Schema.UnionSchema(_, _, alts, _) =>
val altNames = alts.map(_.label).toList
assertEquals(altNames, names)
case unexpected => fail("Unexpected schema: " + unexpected)
}
}
}
}

}

0 comments on commit 20f42e7

Please sign in to comment.