diff --git a/modules/json/src/smithy4s/http/json/Cursor.scala b/modules/json/src/smithy4s/http/json/Cursor.scala index 0dc6e513e..7e1b05224 100644 --- a/modules/json/src/smithy4s/http/json/Cursor.scala +++ b/modules/json/src/smithy4s/http/json/Cursor.scala @@ -22,63 +22,84 @@ import com.github.plokhotnyuk.jsoniter_scala.core.JsonReaderException import smithy4s.http.PayloadError class Cursor private () { - private[this] var stack: Array[PayloadPath.Segment] = - new Array[PayloadPath.Segment](8) - private[this] var top: Int = 0 - private var expecting: String = null + private[this] var indexStack: Array[Int] = new Array[Int](8) + private[this] var labelStack: Array[String] = new Array[String](8) + private[this] var top: Int = _ + private var expecting: String = _ def decode[A](codec: JCodec[A], in: JsonReader): A = { this.expecting = codec.expecting codec.decodeValue(this, in) } - def under[A](segment: PayloadPath.Segment)(f: => A): A = { - if (top >= stack.length) stack = java.util.Arrays.copyOf(stack, top << 1) - stack(top) = segment - top += 1 + def under[A](segment: PayloadPath.Segment)(f: => A): A = + segment match { + case i: PayloadPath.Segment.Index => under(i.index)(f) + case l: PayloadPath.Segment.Label => under(l.label)(f) + } + + def under[A](label: String)(f: => A): A = { + push(label) val res = f - top -= 1 + pop() res } - def under[A](label: String)(f: => A): A = - under(new PayloadPath.Segment.Label(label))(f) + def under[A](index: Int)(f: => A): A = { + push(index) + val res = f + pop() + res + } - def under[A](index: Int)(f: => A): A = - under(new PayloadPath.Segment.Index(index))(f) + def push(label: String): Unit = { + if (top >= labelStack.length) growStacks() + labelStack(top) = label + top += 1 + } + + def push(index: Int): Unit = { + if (top >= indexStack.length) growStacks() + indexStack(top) = index + top += 1 + } + + def pop(): Unit = top -= 1 def payloadError[A](codec: JCodec[A], message: String): Nothing = - throw PayloadError(getPath(), codec.expecting, message) + throw new PayloadError(getPath(Nil), codec.expecting, message) def requiredFieldError[A](codec: JCodec[A], field: String): Nothing = requiredFieldError(codec.expecting, field) def requiredFieldError[A](expecting: String, field: String): Nothing = { - var top = this.top - if (top >= stack.length) stack = java.util.Arrays.copyOf(stack, top << 1) - stack(top) = new PayloadPath.Segment.Label(field) - top += 1 - var list: List[PayloadPath.Segment] = Nil - while (top > 0) { - top -= 1 - list = stack(top) :: list - } - throw PayloadError(PayloadPath(list), expecting, "Missing required field") + val path = getPath(new PayloadPath.Segment.Label(field) :: Nil) + throw new PayloadError(path, expecting, "Missing required field") } - private def getPath(): PayloadPath = { + private def getPath(segments: List[PayloadPath.Segment]): PayloadPath = { var top = this.top - var list: List[PayloadPath.Segment] = Nil + var list = segments while (top > 0) { top -= 1 - list = stack(top) :: list + val label = labelStack(top) + val segment = + if (label ne null) new PayloadPath.Segment.Label(label) + else new PayloadPath.Segment.Index(indexStack(top)) + list = segment :: list } - PayloadPath(list) + new PayloadPath(list) } private def getExpected(): String = if (expecting != null) expecting else throw new IllegalStateException("Expected should have been fulfilled") + + private[this] def growStacks(): Unit = { + val size = top << 1 + labelStack = java.util.Arrays.copyOf(labelStack, size) + indexStack = java.util.Arrays.copyOf(indexStack, size) + } } object Cursor { @@ -86,14 +107,13 @@ object Cursor { def withCursor[A](expecting: String)(f: Cursor => A): A = { val cursor = new Cursor() cursor.expecting = expecting - try { - f(cursor) - } catch { + try f(cursor) + catch { case e: JsonReaderException => payloadError(cursor, e.getMessage()) case e: ConstraintError => payloadError(cursor, e.message) } } private[this] def payloadError(cursor: Cursor, message: String): Nothing = - throw PayloadError(cursor.getPath(), cursor.getExpected(), message) + throw new PayloadError(cursor.getPath(Nil), cursor.getExpected(), message) } diff --git a/modules/json/src/smithy4s/http/json/SchemaVisitorJCodec.scala b/modules/json/src/smithy4s/http/json/SchemaVisitorJCodec.scala index f7ef285d2..22b088875 100644 --- a/modules/json/src/smithy4s/http/json/SchemaVisitorJCodec.scala +++ b/modules/json/src/smithy4s/http/json/SchemaVisitorJCodec.scala @@ -462,7 +462,9 @@ private[smithy4s] class SchemaVisitorJCodec( var i = 0 while ({ if (i >= maxArity) maxArityError(cursor) - builder += cursor.under(i)(cursor.decode(a, in)) + cursor.push(i) + builder += cursor.decode(a, in) + cursor.pop() i += 1 in.isNextToken(',') }) () @@ -512,7 +514,9 @@ private[smithy4s] class SchemaVisitorJCodec( var i = 0 while ({ if (i >= maxArity) maxArityError(cursor) - builder += cursor.under(i)(cursor.decode(a, in)) + cursor.push(i) + builder += cursor.decode(a, in) + cursor.pop() i += 1 in.isNextToken(',') }) () @@ -559,7 +563,9 @@ private[smithy4s] class SchemaVisitorJCodec( var i = 0 while ({ if (i >= maxArity) maxArityError(cursor) - put(cursor.under(i)(cursor.decode(a, in))) + cursor.push(i) + put(cursor.decode(a, in)) + cursor.pop() i += 1 in.isNextToken(',') }) () @@ -616,7 +622,9 @@ private[smithy4s] class SchemaVisitorJCodec( var i = 0 while ({ if (i >= maxArity) maxArityError(cursor) - builder += cursor.under(i)(cursor.decode(a, in)) + cursor.push(i) + builder += cursor.decode(a, in) + cursor.pop() i += 1 in.isNextToken(',') }) () @@ -663,8 +671,12 @@ private[smithy4s] class SchemaVisitorJCodec( if (i >= maxArity) maxArityError(cursor) builder += ( ( - jk.decodeKey(in), - cursor.under(i)(cursor.decode(jv, in)) + jk.decodeKey(in), { + cursor.push(i) + val result = cursor.decode(jv, in) + cursor.pop() + result + } ) ) i += 1 @@ -798,11 +810,11 @@ private[smithy4s] class SchemaVisitorJCodec( else { in.rollbackToken() val key = in.readKeyAsString() - val result = cursor.under(key) { - val handler = handlerMap.get(key) - if (handler eq null) in.discriminatorValueError(key) - handler(cursor, in) - } + cursor.push(key) + val handler = handlerMap.get(key) + if (handler eq null) in.discriminatorValueError(key) + val result = handler(cursor, in) + cursor.pop() if (in.isNextToken('}')) result else { in.rollbackToken() @@ -931,11 +943,12 @@ private[smithy4s] class SchemaVisitorJCodec( val key = in.readString("") in.rollbackToMark() in.rollbackToken() - cursor.under(key) { - val handler = handlerMap.get(key) - if (handler eq null) in.discriminatorValueError(key) - handler(cursor, in) - } + cursor.push(key) + val handler = handlerMap.get(key) + if (handler eq null) in.discriminatorValueError(key) + val result = handler(cursor, in) + cursor.pop() + result } else in.decodeError( s"Unable to find discriminator ${discriminated.value}" @@ -1071,14 +1084,23 @@ private[smithy4s] class SchemaVisitorJCodec( val codec = apply(field.instance) val label = field.label if (field.isRequired) { (cursor, in, mmap) => - val _ = mmap.put(label, cursor.under(label)(cursor.decode(codec, in))) + val _ = mmap.put( + label, { + cursor.push(label) + val result = cursor.decode(codec, in) + cursor.pop() + result + } + ) } else { (cursor, in, mmap) => - cursor.under[Unit](label) { + { + cursor.push(label) if (in.isNextToken('n')) in.readNullOrError[Unit]((), "Expected null") else { in.rollbackToken() val _ = mmap.put(label, cursor.decode(codec, in)) } + cursor.pop() } } }