From ec62a3366530f81fea023d767d0e933813037da1 Mon Sep 17 00:00:00 2001 From: ghik Date: Wed, 13 Jun 2018 13:23:55 +0200 Subject: [PATCH 1/4] BigInteger/BigDecimal native support in Input/Output --- .../commons/ser/GenCodecBenchmarks.scala | 2 + .../commons/ser/CirceJsonInputOutput.scala | 9 ++- .../commons/serialization/GenCodec.scala | 8 +-- .../commons/serialization/InputOutput.scala | 4 ++ .../SimpleValueInputOutput.scala | 67 ++++++++++--------- .../serialization/StreamInputOutput.scala | 61 ++++++++++++++--- .../serialization/json/JsonStringInput.scala | 7 +- .../serialization/json/JsonStringOutput.scala | 2 + .../serialization/StreamInputOutputTest.scala | 14 ++-- .../json/JsonStringInputOutputTest.scala | 3 +- .../json/SerializationTestUtils.scala | 8 ++- .../commons/mongo/BsonInputOutput.scala | 20 ++++++ .../commons/mongo/BsonReaderInput.scala | 2 + .../commons/mongo/BsonValueOutput.scala | 2 + .../commons/mongo/BsonWriterOutput.scala | 8 +++ .../mongo/BigDecimalEncodingTest.scala | 14 ++++ 16 files changed, 172 insertions(+), 59 deletions(-) create mode 100644 commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala diff --git a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala index e01f17402..0c1fe453a 100644 --- a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala +++ b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala @@ -39,5 +39,7 @@ object DummyInput extends Input { def readList() = ignored def readBoolean() = ignored def readDouble() = ignored + def readBigInteger() = ignored + def readBigDecimal() = ignored def skip() = () } diff --git a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala index dc1617c9c..88206ede5 100644 --- a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala +++ b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala @@ -22,7 +22,9 @@ class CirceJsonOutput(consumer: Json => Any) extends Output { def writeInt(int: Int): Unit = consumer(Json.fromInt(int)) def writeLong(long: Long): Unit = consumer(Json.fromLong(long)) def writeDouble(double: Double): Unit = consumer(Json.fromDoubleOrString(double)) - def writeBinary(binary: Array[Byte]): Unit = ??? + def writeBigInteger(bigInteger: JBigInteger): Unit = consumer(Json.fromBigInt(BigInt(bigInteger))) + def writeBigDecimal(bigDecimal: JBigDecimal): Unit = consumer(Json.fromBigDecimal(BigDecimal(bigDecimal))) + def writeBinary(binary: Array[Byte]): Unit = consumer(Json.fromValues(binary.map(Json.fromInt(_)))) def writeList(): ListOutput = new CirceJsonListOutput(consumer) def writeObject(): ObjectOutput = new CirceJsonObjectOutput(consumer) override def writeFloat(float: Float): Unit = consumer(Json.fromFloatOrString(float)) @@ -67,7 +69,10 @@ class CirceJsonInput(json: Json) extends Input { def readInt(): Int = asNumber.toInt.getOrElse(failNot("int")) def readLong(): Long = asNumber.toLong.getOrElse(failNot("long")) def readDouble(): Double = asNumber.toDouble - def readBinary(): Array[Byte] = ??? + def readBigInteger(): JBigInteger = asNumber.toBigInt.getOrElse(failNot("bigInteger")).bigInteger + def readBigDecimal(): JBigDecimal = asNumber.toBigDecimal.getOrElse(failNot("bigDecimal")).bigDecimal + def readBinary(): Array[Byte] = json.asArray.getOrElse(failNot("array")).iterator + .map(_.asNumber.flatMap(_.toByte).getOrElse(failNot("byte"))).toArray def readList(): ListInput = new CirceJsonListInput(json.asArray.getOrElse(failNot("array"))) def readObject(): ObjectInput = new CirceJsonObjectInput(json.asObject.getOrElse(failNot("object"))) def skip(): Unit = () diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala index d658b317b..3ee707d1c 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala @@ -287,8 +287,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val LongCodec: GenCodec[Long] = create(_.readLong(), _ writeLong _) implicit lazy val FloatCodec: GenCodec[Float] = create(_.readFloat(), _ writeFloat _) implicit lazy val DoubleCodec: GenCodec[Double] = create(_.readDouble(), _ writeDouble _) - implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(i => BigInt(i.readString()), (o, v) => o.writeString(v.toString)) - implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(i => BigDecimal(i.readString()), (o, v) => o.writeString(v.toString)) + implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(i => BigInt(i.readBigInteger()), (o, v) => o.writeBigInteger(v.bigInteger)) + implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(i => BigDecimal(i.readBigDecimal()), (o, v) => o.writeBigDecimal(v.bigDecimal)) implicit lazy val JBooleanCodec: GenCodec[JBoolean] = createNullable(_.readBoolean(), _ writeBoolean _) implicit lazy val JCharacterCodec: GenCodec[JCharacter] = createNullable(_.readChar(), _ writeChar _) @@ -298,8 +298,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val JLongCodec: GenCodec[JLong] = createNullable(_.readLong(), _ writeLong _) implicit lazy val JFloatCodec: GenCodec[JFloat] = createNullable(_.readFloat(), _ writeFloat _) implicit lazy val JDoubleCodec: GenCodec[JDouble] = createNullable(_.readDouble(), _ writeDouble _) - implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = createNullable(i => new JBigInteger(i.readString()), (o, v) => o.writeString(v.toString)) - implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = createNullable(i => new JBigDecimal(i.readString()), (o, v) => o.writeString(v.toString)) + implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = createNullable(_.readBigInteger(), _ writeBigInteger _) + implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = createNullable(_.readBigDecimal(), _ writeBigDecimal _) implicit lazy val JDateCodec: GenCodec[JDate] = createNullable(i => new JDate(i.readTimestamp()), (o, d) => o.writeTimestamp(d.getTime)) implicit lazy val StringCodec: GenCodec[String] = createNullable(_.readString(), _ writeString _) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala index 559da7085..55383d040 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala @@ -22,6 +22,8 @@ trait Output extends Any { def writeTimestamp(millis: Long): Unit = writeLong(millis) def writeFloat(float: Float): Unit = writeDouble(float) def writeDouble(double: Double): Unit + def writeBigInteger(bigInteger: JBigInteger): Unit + def writeBigDecimal(bigDecimal: JBigDecimal): Unit def writeBinary(binary: Array[Byte]): Unit def writeList(): ListOutput def writeObject(): ObjectOutput @@ -139,6 +141,8 @@ trait Input extends Any { def readTimestamp(): Long = readLong() def readFloat(): Float = readDouble().toFloat def readDouble(): Double + def readBigInteger(): JBigInteger + def readBigDecimal(): JBigDecimal def readBinary(): Array[Byte] def readList(): ListInput def readObject(): ObjectInput diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala index ac7571160..3d38ae6fd 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala @@ -28,6 +28,8 @@ object SimpleValueOutput { * - `Int` * - `Long` * - `Double` + * - `java.math.BigInteger` + * - `java.math.BigDecimal` * - `Boolean` * - `String` * - `Array[Byte]` @@ -49,27 +51,27 @@ class SimpleValueOutput( def this(consumer: Any => Unit) = this(consumer, new MHashMap[String, Any], new ListBuffer[Any]) - def writeBinary(binary: Array[Byte]) = consumer(binary) - def writeString(str: String) = consumer(str) - def writeDouble(double: Double) = consumer(double) - def writeInt(int: Int) = consumer(int) - - def writeList() = new ListOutput { + def writeNull(): Unit = consumer(null) + def writeBoolean(boolean: Boolean): Unit = consumer(boolean) + def writeString(str: String): Unit = consumer(str) + def writeInt(int: Int): Unit = consumer(int) + def writeLong(long: Long): Unit = consumer(long) + def writeDouble(double: Double): Unit = consumer(double) + def writeBigInteger(bigInteger: JBigInteger): Unit = consumer(bigInteger) + def writeBigDecimal(bigDecimal: JBigDecimal): Unit = consumer(bigDecimal) + def writeBinary(binary: Array[Byte]): Unit = consumer(binary) + + def writeList(): ListOutput = new ListOutput { private val buffer = newListRepr def writeElement() = new SimpleValueOutput(buffer += _, newObjectRepr, newListRepr) - def finish() = consumer(buffer.result()) + def finish(): Unit = consumer(buffer.result()) } - def writeBoolean(boolean: Boolean) = consumer(boolean) - - def writeObject() = new ObjectOutput { + def writeObject(): ObjectOutput = new ObjectOutput { private val result = newObjectRepr def writeField(key: String) = new SimpleValueOutput(v => result += ((key, v)), newObjectRepr, newListRepr) - def finish() = consumer(result) + def finish(): Unit = consumer(result) } - - def writeLong(long: Long) = consumer(long) - def writeNull() = consumer(null) } object SimpleValueInput { @@ -91,41 +93,42 @@ class SimpleValueInput(value: Any) extends Input { case _ => throw new ReadFailure(s"Expected ${classTag[B].runtimeClass} but got ${value.getClass}") } - def inputType = value match { + def inputType: InputType = value match { case null => InputType.Null case _: BSeq[Any] => InputType.List case _: BMap[_, Any] => InputType.Object case _ => InputType.Simple } - def readBinary() = doRead[Array[Byte]] - def readLong() = doReadUnboxed[Long, JLong] - def readNull() = if (value == null) null else throw new ReadFailure("not null") - def readObject() = + def readNull(): Null = if (value == null) null else throw new ReadFailure("not null") + def readBoolean(): Boolean = doReadUnboxed[Boolean, JBoolean] + def readString(): String = doRead[String] + def readInt(): Int = doReadUnboxed[Int, JInteger] + def readLong(): Long = doReadUnboxed[Long, JLong] + def readDouble(): Double = doReadUnboxed[Double, JDouble] + def readBigInteger(): JBigInteger = doRead[JBigInteger] + def readBigDecimal(): JBigDecimal = doRead[JBigDecimal] + def readBinary(): Array[Byte] = doRead[Array[Byte]] + + def readObject(): ObjectInput = new ObjectInput { private val map = doRead[BMap[String, Any]] private val it = map.iterator.map { case (k, v) => new SimpleValueFieldInput(k, v) } - def nextField() = it.next() - override def peekField(name: String) = map.getOpt(name).map(new SimpleValueFieldInput(name, _)) - def hasNext = it.hasNext + def nextField(): SimpleValueFieldInput = it.next() + override def peekField(name: String): Opt[SimpleValueFieldInput] = map.getOpt(name).map(new SimpleValueFieldInput(name, _)) + def hasNext: Boolean = it.hasNext } - def readInt() = doReadUnboxed[Int, JInteger] - def readString() = doRead[String] - - def readList() = + def readList(): ListInput = new ListInput { private val it = doRead[BSeq[Any]].iterator.map(new SimpleValueInput(_)) - def nextElement() = it.next() - def hasNext = it.hasNext + def nextElement(): SimpleValueInput = it.next() + def hasNext: Boolean = it.hasNext } - def readBoolean() = doReadUnboxed[Boolean, JBoolean] - def readDouble() = doReadUnboxed[Double, JDouble] - - def skip() = () + def skip(): Unit = () } class SimpleValueFieldInput(val fieldName: String, value: Any) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala index 7cab613f9..b716715cc 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala @@ -29,6 +29,8 @@ private object FormatConstants { final val ObjectStartMarker: Byte = 11 final val ListEndMarker: Byte = 12 final val ObjectEndMarker: Byte = 13 + final val BigIntegerMarker: Byte = 14 + final val BigDecimalMarker: Byte = 15 } import com.avsystem.commons.serialization.FormatConstants._ @@ -50,50 +52,68 @@ class StreamInput(is: DataInputStream) extends Input { def readNull(): Null = if (markerByte == NullMarker) null else - throw new ReadFailure(s"Expected null, but $markerByte found") + throw new ReadFailure(s"Expected null but $markerByte found") def readString(): String = if (markerByte == StringMarker) is.readUTF() else - throw new ReadFailure(s"Expected string, but $markerByte found") + throw new ReadFailure(s"Expected string but $markerByte found") def readBoolean(): Boolean = if (markerByte == BooleanMarker) is.readBoolean() else - throw new ReadFailure(s"Expected boolean, but $markerByte found") + throw new ReadFailure(s"Expected boolean but $markerByte found") def readInt(): Int = if (markerByte == IntMarker) is.readInt() else - throw new ReadFailure(s"Expected int, but $markerByte found") + throw new ReadFailure(s"Expected int but $markerByte found") def readLong(): Long = if (markerByte == LongMarker) is.readLong() else - throw new ReadFailure(s"Expected long, but $markerByte found") + throw new ReadFailure(s"Expected long but $markerByte found") def readDouble(): Double = if (markerByte == DoubleMarker) is.readDouble() else - throw new ReadFailure(s"Expected double, but $markerByte found") + throw new ReadFailure(s"Expected double but $markerByte found") + + def readBigInteger(): JBigInteger = if (markerByte == BigIntegerMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + new JBigInteger(bytes) + } else + throw new ReadFailure(s"Expected big integer but $markerByte found") + + def readBigDecimal(): JBigDecimal = if (markerByte == BigDecimalMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + val unscaled = new JBigInteger(bytes) + val scale = is.readInt() + new JBigDecimal(unscaled, scale) + } else + throw new ReadFailure(s"Expected big decimal but $markerByte found") def readBinary(): Array[Byte] = if (markerByte == ByteArrayMarker) { val binary = Array.ofDim[Byte](is.readInt()) is.readFully(binary) binary } else { - throw new ReadFailure(s"Expected binary array, but $markerByte found") + throw new ReadFailure(s"Expected binary array but $markerByte found") } def readList(): ListInput = if (markerByte == ListStartMarker) new StreamListInput(is) else - throw new ReadFailure(s"Expected list, but $markerByte found") + throw new ReadFailure(s"Expected list but $markerByte found") def readObject(): ObjectInput = if (markerByte == ObjectStartMarker) new StreamObjectInput(is) else - throw new ReadFailure(s"Expected object, but $markerByte found") + throw new ReadFailure(s"Expected object but $markerByte found") def skip(): Unit = markerByte match { case NullMarker => @@ -121,6 +141,10 @@ class StreamInput(is: DataInputStream) extends Input { new StreamListInput(is).skipRemaining() case ObjectStartMarker => new StreamObjectInput(is).skipRemaining() + case BigIntegerMarker => + is.skipBytes(is.readInt()) + case BigDecimalMarker => + is.skipBytes(is.readInt() + 4) case unexpected => throw new ReadFailure(s"Unexpected marker byte: $unexpected") } @@ -182,14 +206,16 @@ private object StreamObjectInput { case class EmptyFieldInput(name: String) extends FieldInput { private def nope: Nothing = throw new ReadFailure(s"Something went horribly wrong ($name)") - def fieldName: String = nope def inputType: InputType = nope + def fieldName: String = nope def readNull(): Null = nope def readString(): String = nope def readBoolean(): Boolean = nope def readInt(): Int = nope def readLong(): Long = nope def readDouble(): Double = nope + def readBigInteger(): JBigInteger = nope + def readBigDecimal(): JBigDecimal = nope def readBinary(): Array[Byte] = nope def readList(): ListInput = nope def readObject(): ObjectInput = nope @@ -232,6 +258,21 @@ class StreamOutput(os: DataOutputStream) extends Output { os.writeDouble(double) } + def writeBigInteger(bigInteger: JBigInteger): Unit = { + os.writeByte(BigIntegerMarker) + val bytes = bigInteger.toByteArray + os.writeInt(bytes.length) + os.write(bytes) + } + + def writeBigDecimal(bigDecimal: JBigDecimal): Unit = { + os.writeByte(BigDecimalMarker) + val bytes = bigDecimal.unscaledValue.toByteArray + os.writeInt(bytes.length) + os.write(bytes) + os.writeInt(bigDecimal.scale) + } + def writeBinary(binary: Array[Byte]): Unit = { os.writeByte(ByteArrayMarker) os.writeInt(binary.length) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala index f065c0584..d942997d7 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala @@ -33,7 +33,8 @@ class JsonStringInput(reader: JsonReader, callback: AfterElement = AfterElementN case _ => afterElement() } - private def expectedError(tpe: JsonType) = throw new ReadFailure(s"Expected $tpe but got ${reader.jsonType}: ${reader.currentValue}") + private def expectedError(tpe: JsonType) = + throw new ReadFailure(s"Expected $tpe but got ${reader.jsonType}: ${reader.currentValue}") private def checkedValue[T](jsonType: JsonType): T = { if (reader.jsonType != jsonType) expectedError(jsonType) @@ -64,6 +65,8 @@ class JsonStringInput(reader: JsonReader, callback: AfterElement = AfterElementN def readInt(): Int = matchNumericString(_.toInt) def readLong(): Long = matchNumericString(_.toLong) def readDouble(): Double = matchNumericString(_.toDouble) + def readBigInteger(): JBigInteger = matchNumericString(new JBigInteger(_)) + def readBigDecimal(): JBigDecimal = matchNumericString(new JBigDecimal(_)) def readBinary(): Array[Byte] = { val hex = checkedValue[String](JsonType.string) val result = new Array[Byte](hex.length / 2) @@ -213,7 +216,7 @@ final class JsonReader(val json: String) { private def readHex(): Int = fromHex(read()) - private def parseNumber(): Any = { + private def parseNumber(): String = { val start = i if (isNext('-')) { diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala index 4213c1a2c..56cce2866 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala @@ -50,6 +50,8 @@ final class JsonStringOutput(builder: JStringBuilder) extends BaseJsonOutput wit writeString(double.toString) else builder.append(double.toString) } + def writeBigInteger(bigInteger: JBigInteger): Unit = builder.append(bigInteger.toString) + def writeBigDecimal(bigDecimal: JBigDecimal): Unit = builder.append(bigDecimal.toString) def writeBinary(binary: Array[Byte]): Unit = { builder.append('"') var i = 0 diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala index 63a8f4502..4432838f1 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala @@ -22,9 +22,11 @@ case class FieldTypes( i: Int, j: Long, k: Double, - l: Array[Byte], - m: Obj, - n: List[List[Obj]] + l: BigInt, + m: BigDecimal, + n: Array[Byte], + o: Obj, + p: List[List[Obj]] ) class StreamInputOutputTest extends FunSuite { @@ -42,6 +44,8 @@ class StreamInputOutputTest extends FunSuite { -5, -6, -7.3, + BigInt("5345224654563123434325343"), + BigDecimal(BigInt("2356342454564522135435"), 150), Array[Byte](1, 2, 4, 2), Obj(10, "x"), List( @@ -90,8 +94,8 @@ class StreamInputOutputTest extends FunSuite { test("encode and decode all field types in a complicated structure") { val encoded = encDec(fieldTypesInstance) - assert(fieldTypesInstance.l sameElements encoded.l) - assert(fieldTypesInstance == encoded.copy(l = fieldTypesInstance.l)) + assert(fieldTypesInstance.n sameElements encoded.n) + assert(fieldTypesInstance == encoded.copy(n = fieldTypesInstance.n)) } test("raw API usage") { diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala index de1f499bf..fca6b1ff2 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala @@ -199,6 +199,8 @@ class JsonStringInputOutputTest extends FunSuite with SerializationTestUtils wit deserialized.i2.long shouldBe item.i2.long deserialized.i2.float shouldBe item.i2.float deserialized.i2.double shouldBe item.i2.double + deserialized.i2.bigInt shouldBe item.i2.bigInt + deserialized.i2.bigDecimal shouldBe item.i2.bigDecimal deserialized.i2.binary shouldBe item.i2.binary deserialized.i2.list shouldBe item.i2.list deserialized.i2.set shouldBe item.i2.set @@ -207,7 +209,6 @@ class JsonStringInputOutputTest extends FunSuite with SerializationTestUtils wit } } - test("serialize and deserialize huge case classes") { implicit val arbTree: Arbitrary[DeepNestedTestCC] = Arbitrary { diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala index 9ba346748..00a65f2f8 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala @@ -26,8 +26,8 @@ trait SerializationTestUtils { case class CompleteItem( unit: Unit, string: String, char: Char, boolean: Boolean, byte: Byte, short: Short, int: Int, - long: Long, float: Float, double: Double, binary: Array[Byte], list: List[String], - set: Set[String], obj: TestCC, map: Map[String, Int] + long: Long, float: Float, double: Double, bigInt: BigInt, bigDecimal: BigDecimal, + binary: Array[Byte], list: List[String], set: Set[String], obj: TestCC, map: Map[String, Int] ) object CompleteItem extends HasGenCodec[CompleteItem] { implicit val arb: Arbitrary[CompleteItem] = Arbitrary(for { @@ -41,11 +41,13 @@ trait SerializationTestUtils { l <- arbitrary[Long] f <- arbitrary[Float] d <- arbitrary[Double] + bi <- arbitrary[BigInt] + bd <- arbitrary[BigDecimal] binary <- arbitrary[Array[Byte]] list <- arbitrary[List[String]] set <- arbitrary[Set[String]] obj <- arbitrary[TestCC] map <- arbitrary[Map[String, Int]] - } yield CompleteItem(u, str, c, bool, b, s, i, l, f, d, binary, list, set, obj, map)) + } yield CompleteItem(u, str, c, bool, b, s, i, l, f, d, bi, bd, binary, list, set, obj, map)) } } diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala index dd8f34e8e..c1dbc802e 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala @@ -1,6 +1,8 @@ package com.avsystem.commons package mongo +import java.nio.ByteBuffer + import com.avsystem.commons.serialization.{Input, Output} import org.bson.types.ObjectId @@ -8,6 +10,24 @@ trait BsonInput extends Any with Input { def readObjectId(): ObjectId } +object BsonInput { + def bigDecimalFromBytes(bytes: Array[Byte]): JBigDecimal = { + val buf = ByteBuffer.wrap(bytes) + val unscaledBytes = new Array[Byte](bytes.length - 4) + buf.get(unscaledBytes) + val unscaled = new JBigInteger(unscaledBytes) + val scale = buf.getInt + new JBigDecimal(unscaled, scale) + } +} + trait BsonOutput extends Any with Output { def writeObjectId(objectId: ObjectId): Unit } + +object BsonOutput { + def bigDecimalBytes(bigDecimal: JBigDecimal): Array[Byte] = { + val unscaledBytes = bigDecimal.unscaledValue.toByteArray + ByteBuffer.allocate(unscaledBytes.length + 4).put(unscaledBytes).putInt(bigDecimal.scale).array + } +} diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala index 4c36887e6..e0a0bb6e0 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala @@ -24,6 +24,8 @@ class BsonReaderInput(br: BsonReader) extends BsonInput { override def readLong(): Long = br.readInt64() override def readTimestamp(): Long = br.readDateTime() override def readDouble(): Double = br.readDouble() + override def readBigInteger(): JBigInteger = new JBigInteger(br.readBinaryData().getData) + override def readBigDecimal(): JBigDecimal = BsonInput.bigDecimalFromBytes(br.readBinaryData().getData) override def readBinary(): Array[Byte] = br.readBinaryData().getData override def readList(): BsonReaderListInput = { br.readStartArray() diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala index c3ed12341..d567b3ad0 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala @@ -33,6 +33,8 @@ final class BsonValueOutput(receiver: BsonValue => Unit = _ => ()) extends BsonO override def writeLong(long: Long): Unit = setValue(new BsonInt64(long)) override def writeTimestamp(millis: Long): Unit = setValue(new BsonDateTime(millis)) override def writeDouble(double: Double): Unit = setValue(new BsonDouble(double)) + override def writeBigInteger(bigInteger: JBigInteger): Unit = setValue(new BsonBinary(bigInteger.toByteArray)) + override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = setValue(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = setValue(new BsonBinary(binary)) override def writeList(): ListOutput = new BsonValueListOutput(setValue) override def writeObject(): ObjectOutput = new BsonValueObjectOutput(setValue) diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala index 12e373d57..d77aeb28c 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala @@ -13,6 +13,10 @@ final class BsonWriterOutput(bw: BsonWriter) extends BsonOutput { override def writeLong(long: Long): Unit = bw.writeInt64(long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(millis) override def writeDouble(double: Double): Unit = bw.writeDouble(double) + override def writeBigInteger(bigInteger: JBigInteger): Unit = + bw.writeBinaryData(new BsonBinary(bigInteger.toByteArray)) + override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = + bw.writeBinaryData(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { bw.writeStartArray() @@ -33,6 +37,10 @@ final class BsonWriterNamedOutput(escapedName: String, bw: BsonWriter) extends B override def writeLong(long: Long): Unit = bw.writeInt64(escapedName, long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(escapedName, millis) override def writeDouble(double: Double): Unit = bw.writeDouble(escapedName, double) + override def writeBigInteger(bigInteger: JBigInteger): Unit = + bw.writeBinaryData(escapedName, new BsonBinary(bigInteger.toByteArray)) + override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = + bw.writeBinaryData(escapedName, new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(escapedName, new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { bw.writeStartArray(escapedName) diff --git a/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala new file mode 100644 index 000000000..eee242869 --- /dev/null +++ b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala @@ -0,0 +1,14 @@ +package com.avsystem.commons +package mongo + +import org.scalacheck.Arbitrary +import org.scalatest.FunSuite +import org.scalatest.prop.PropertyChecks + +class BigDecimalEncodingTest extends FunSuite with PropertyChecks { + test("BigDecimal BSON encoding") { + forAll(Arbitrary.arbitrary[BigDecimal].map(_.bigDecimal)) { value => + assert(value == BsonInput.bigDecimalFromBytes(BsonOutput.bigDecimalBytes(value))) + } + } +} From 2985995334baa9bcc94af0f5b60e799c8ec48150 Mon Sep 17 00:00:00 2001 From: ghik Date: Wed, 13 Jun 2018 14:12:53 +0200 Subject: [PATCH 2/4] using Scala BigInt/BigDecimal instead of JBigInteger/JBigDecimal --- .../commons/ser/GenCodecBenchmarks.scala | 2 +- .../commons/ser/CirceJsonInputOutput.scala | 8 +++--- .../commons/serialization/GenCodec.scala | 10 ++++--- .../commons/serialization/InputOutput.scala | 8 +++--- .../SimpleValueInputOutput.scala | 12 ++++---- .../serialization/StreamInputOutput.scala | 28 +++++++++---------- .../serialization/json/JsonStringInput.scala | 4 +-- .../serialization/json/JsonStringOutput.scala | 4 +-- .../commons/mongo/BsonInputOutput.scala | 10 +++---- .../commons/mongo/BsonReaderInput.scala | 4 +-- .../commons/mongo/BsonValueOutput.scala | 4 +-- .../commons/mongo/BsonWriterOutput.scala | 12 ++++---- .../mongo/BigDecimalEncodingTest.scala | 3 +- 13 files changed, 55 insertions(+), 54 deletions(-) diff --git a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala index 0c1fe453a..b6067969f 100644 --- a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala +++ b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala @@ -39,7 +39,7 @@ object DummyInput extends Input { def readList() = ignored def readBoolean() = ignored def readDouble() = ignored - def readBigInteger() = ignored + def readBigInt() = ignored def readBigDecimal() = ignored def skip() = () } diff --git a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala index 88206ede5..ca104edee 100644 --- a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala +++ b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala @@ -22,8 +22,8 @@ class CirceJsonOutput(consumer: Json => Any) extends Output { def writeInt(int: Int): Unit = consumer(Json.fromInt(int)) def writeLong(long: Long): Unit = consumer(Json.fromLong(long)) def writeDouble(double: Double): Unit = consumer(Json.fromDoubleOrString(double)) - def writeBigInteger(bigInteger: JBigInteger): Unit = consumer(Json.fromBigInt(BigInt(bigInteger))) - def writeBigDecimal(bigDecimal: JBigDecimal): Unit = consumer(Json.fromBigDecimal(BigDecimal(bigDecimal))) + def writeBigInt(bigInt: BigInt): Unit = consumer(Json.fromBigInt(bigInt)) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = consumer(Json.fromBigDecimal(bigDecimal)) def writeBinary(binary: Array[Byte]): Unit = consumer(Json.fromValues(binary.map(Json.fromInt(_)))) def writeList(): ListOutput = new CirceJsonListOutput(consumer) def writeObject(): ObjectOutput = new CirceJsonObjectOutput(consumer) @@ -69,8 +69,8 @@ class CirceJsonInput(json: Json) extends Input { def readInt(): Int = asNumber.toInt.getOrElse(failNot("int")) def readLong(): Long = asNumber.toLong.getOrElse(failNot("long")) def readDouble(): Double = asNumber.toDouble - def readBigInteger(): JBigInteger = asNumber.toBigInt.getOrElse(failNot("bigInteger")).bigInteger - def readBigDecimal(): JBigDecimal = asNumber.toBigDecimal.getOrElse(failNot("bigDecimal")).bigDecimal + def readBigInt(): BigInt = asNumber.toBigInt.getOrElse(failNot("bigInteger")) + def readBigDecimal(): BigDecimal = asNumber.toBigDecimal.getOrElse(failNot("bigDecimal")) def readBinary(): Array[Byte] = json.asArray.getOrElse(failNot("array")).iterator .map(_.asNumber.flatMap(_.toByte).getOrElse(failNot("byte"))).toArray def readList(): ListInput = new CirceJsonListInput(json.asArray.getOrElse(failNot("array"))) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala index 3ee707d1c..8257ca0ab 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala @@ -287,8 +287,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val LongCodec: GenCodec[Long] = create(_.readLong(), _ writeLong _) implicit lazy val FloatCodec: GenCodec[Float] = create(_.readFloat(), _ writeFloat _) implicit lazy val DoubleCodec: GenCodec[Double] = create(_.readDouble(), _ writeDouble _) - implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(i => BigInt(i.readBigInteger()), (o, v) => o.writeBigInteger(v.bigInteger)) - implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(i => BigDecimal(i.readBigDecimal()), (o, v) => o.writeBigDecimal(v.bigDecimal)) + implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(_.readBigInt(), _ writeBigInt _) + implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(_.readBigDecimal(), _ writeBigDecimal _) implicit lazy val JBooleanCodec: GenCodec[JBoolean] = createNullable(_.readBoolean(), _ writeBoolean _) implicit lazy val JCharacterCodec: GenCodec[JCharacter] = createNullable(_.readChar(), _ writeChar _) @@ -298,8 +298,10 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val JLongCodec: GenCodec[JLong] = createNullable(_.readLong(), _ writeLong _) implicit lazy val JFloatCodec: GenCodec[JFloat] = createNullable(_.readFloat(), _ writeFloat _) implicit lazy val JDoubleCodec: GenCodec[JDouble] = createNullable(_.readDouble(), _ writeDouble _) - implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = createNullable(_.readBigInteger(), _ writeBigInteger _) - implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = createNullable(_.readBigDecimal(), _ writeBigDecimal _) + implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = + createNullable(_.readBigInt().bigInteger, (o, v) => o.writeBigInt(BigInt(v))) + implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = + createNullable(_.readBigDecimal().bigDecimal, (o, v) => o.writeBigDecimal(BigDecimal(v))) implicit lazy val JDateCodec: GenCodec[JDate] = createNullable(i => new JDate(i.readTimestamp()), (o, d) => o.writeTimestamp(d.getTime)) implicit lazy val StringCodec: GenCodec[String] = createNullable(_.readString(), _ writeString _) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala index 55383d040..5c1d83b1a 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala @@ -22,8 +22,8 @@ trait Output extends Any { def writeTimestamp(millis: Long): Unit = writeLong(millis) def writeFloat(float: Float): Unit = writeDouble(float) def writeDouble(double: Double): Unit - def writeBigInteger(bigInteger: JBigInteger): Unit - def writeBigDecimal(bigDecimal: JBigDecimal): Unit + def writeBigInt(bigInt: BigInt): Unit + def writeBigDecimal(bigDecimal: BigDecimal): Unit def writeBinary(binary: Array[Byte]): Unit def writeList(): ListOutput def writeObject(): ObjectOutput @@ -141,8 +141,8 @@ trait Input extends Any { def readTimestamp(): Long = readLong() def readFloat(): Float = readDouble().toFloat def readDouble(): Double - def readBigInteger(): JBigInteger - def readBigDecimal(): JBigDecimal + def readBigInt(): BigInt + def readBigDecimal(): BigDecimal def readBinary(): Array[Byte] def readList(): ListInput def readObject(): ObjectInput diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala index 3d38ae6fd..3b0d48f84 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala @@ -28,8 +28,8 @@ object SimpleValueOutput { * - `Int` * - `Long` * - `Double` - * - `java.math.BigInteger` - * - `java.math.BigDecimal` + * - `BigInt` + * - `BigDecimal` * - `Boolean` * - `String` * - `Array[Byte]` @@ -57,8 +57,8 @@ class SimpleValueOutput( def writeInt(int: Int): Unit = consumer(int) def writeLong(long: Long): Unit = consumer(long) def writeDouble(double: Double): Unit = consumer(double) - def writeBigInteger(bigInteger: JBigInteger): Unit = consumer(bigInteger) - def writeBigDecimal(bigDecimal: JBigDecimal): Unit = consumer(bigDecimal) + def writeBigInt(bigInt: BigInt): Unit = consumer(bigInt) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = consumer(bigDecimal) def writeBinary(binary: Array[Byte]): Unit = consumer(binary) def writeList(): ListOutput = new ListOutput { @@ -106,8 +106,8 @@ class SimpleValueInput(value: Any) extends Input { def readInt(): Int = doReadUnboxed[Int, JInteger] def readLong(): Long = doReadUnboxed[Long, JLong] def readDouble(): Double = doReadUnboxed[Double, JDouble] - def readBigInteger(): JBigInteger = doRead[JBigInteger] - def readBigDecimal(): JBigDecimal = doRead[JBigDecimal] + def readBigInt(): BigInt = doRead[JBigInteger] + def readBigDecimal(): BigDecimal = doRead[JBigDecimal] def readBinary(): Array[Byte] = doRead[Array[Byte]] def readObject(): ObjectInput = diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala index b716715cc..80ad3059e 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala @@ -29,7 +29,7 @@ private object FormatConstants { final val ObjectStartMarker: Byte = 11 final val ListEndMarker: Byte = 12 final val ObjectEndMarker: Byte = 13 - final val BigIntegerMarker: Byte = 14 + final val BitIntMarker: Byte = 14 final val BigDecimalMarker: Byte = 15 } @@ -79,21 +79,21 @@ class StreamInput(is: DataInputStream) extends Input { else throw new ReadFailure(s"Expected double but $markerByte found") - def readBigInteger(): JBigInteger = if (markerByte == BigIntegerMarker) { + def readBigInt(): BigInt = if (markerByte == BitIntMarker) { val len = is.readInt() val bytes = new Array[Byte](len) is.read(bytes) - new JBigInteger(bytes) + BigInt(bytes) } else throw new ReadFailure(s"Expected big integer but $markerByte found") - def readBigDecimal(): JBigDecimal = if (markerByte == BigDecimalMarker) { + def readBigDecimal(): BigDecimal = if (markerByte == BigDecimalMarker) { val len = is.readInt() val bytes = new Array[Byte](len) is.read(bytes) - val unscaled = new JBigInteger(bytes) + val unscaled = BigInt(bytes) val scale = is.readInt() - new JBigDecimal(unscaled, scale) + BigDecimal(unscaled, scale) } else throw new ReadFailure(s"Expected big decimal but $markerByte found") @@ -141,7 +141,7 @@ class StreamInput(is: DataInputStream) extends Input { new StreamListInput(is).skipRemaining() case ObjectStartMarker => new StreamObjectInput(is).skipRemaining() - case BigIntegerMarker => + case BitIntMarker => is.skipBytes(is.readInt()) case BigDecimalMarker => is.skipBytes(is.readInt() + 4) @@ -214,8 +214,8 @@ private object StreamObjectInput { def readInt(): Int = nope def readLong(): Long = nope def readDouble(): Double = nope - def readBigInteger(): JBigInteger = nope - def readBigDecimal(): JBigDecimal = nope + def readBigInt(): BigInt = nope + def readBigDecimal(): BigDecimal = nope def readBinary(): Array[Byte] = nope def readList(): ListInput = nope def readObject(): ObjectInput = nope @@ -258,16 +258,16 @@ class StreamOutput(os: DataOutputStream) extends Output { os.writeDouble(double) } - def writeBigInteger(bigInteger: JBigInteger): Unit = { - os.writeByte(BigIntegerMarker) - val bytes = bigInteger.toByteArray + def writeBigInt(bigInt: BigInt): Unit = { + os.writeByte(BitIntMarker) + val bytes = bigInt.toByteArray os.writeInt(bytes.length) os.write(bytes) } - def writeBigDecimal(bigDecimal: JBigDecimal): Unit = { + def writeBigDecimal(bigDecimal: BigDecimal): Unit = { os.writeByte(BigDecimalMarker) - val bytes = bigDecimal.unscaledValue.toByteArray + val bytes = bigDecimal.bigDecimal.unscaledValue.toByteArray os.writeInt(bytes.length) os.write(bytes) os.writeInt(bigDecimal.scale) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala index d942997d7..f6f813c6e 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala @@ -65,8 +65,8 @@ class JsonStringInput(reader: JsonReader, callback: AfterElement = AfterElementN def readInt(): Int = matchNumericString(_.toInt) def readLong(): Long = matchNumericString(_.toLong) def readDouble(): Double = matchNumericString(_.toDouble) - def readBigInteger(): JBigInteger = matchNumericString(new JBigInteger(_)) - def readBigDecimal(): JBigDecimal = matchNumericString(new JBigDecimal(_)) + def readBigInt(): BigInt = matchNumericString(BigInt(_)) + def readBigDecimal(): BigDecimal = matchNumericString(BigDecimal(_)) def readBinary(): Array[Byte] = { val hex = checkedValue[String](JsonType.string) val result = new Array[Byte](hex.length / 2) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala index 56cce2866..32c2470eb 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala @@ -50,8 +50,8 @@ final class JsonStringOutput(builder: JStringBuilder) extends BaseJsonOutput wit writeString(double.toString) else builder.append(double.toString) } - def writeBigInteger(bigInteger: JBigInteger): Unit = builder.append(bigInteger.toString) - def writeBigDecimal(bigDecimal: JBigDecimal): Unit = builder.append(bigDecimal.toString) + def writeBigInt(bigInt: BigInt): Unit = builder.append(bigInt.toString) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = builder.append(bigDecimal.toString) def writeBinary(binary: Array[Byte]): Unit = { builder.append('"') var i = 0 diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala index c1dbc802e..362c0aef3 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala @@ -11,13 +11,13 @@ trait BsonInput extends Any with Input { } object BsonInput { - def bigDecimalFromBytes(bytes: Array[Byte]): JBigDecimal = { + def bigDecimalFromBytes(bytes: Array[Byte]): BigDecimal = { val buf = ByteBuffer.wrap(bytes) val unscaledBytes = new Array[Byte](bytes.length - 4) buf.get(unscaledBytes) - val unscaled = new JBigInteger(unscaledBytes) + val unscaled = BigInt(unscaledBytes) val scale = buf.getInt - new JBigDecimal(unscaled, scale) + BigDecimal(unscaled, scale) } } @@ -26,8 +26,8 @@ trait BsonOutput extends Any with Output { } object BsonOutput { - def bigDecimalBytes(bigDecimal: JBigDecimal): Array[Byte] = { - val unscaledBytes = bigDecimal.unscaledValue.toByteArray + def bigDecimalBytes(bigDecimal: BigDecimal): Array[Byte] = { + val unscaledBytes = bigDecimal.bigDecimal.unscaledValue.toByteArray ByteBuffer.allocate(unscaledBytes.length + 4).put(unscaledBytes).putInt(bigDecimal.scale).array } } diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala index e0a0bb6e0..77b256014 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala @@ -24,8 +24,8 @@ class BsonReaderInput(br: BsonReader) extends BsonInput { override def readLong(): Long = br.readInt64() override def readTimestamp(): Long = br.readDateTime() override def readDouble(): Double = br.readDouble() - override def readBigInteger(): JBigInteger = new JBigInteger(br.readBinaryData().getData) - override def readBigDecimal(): JBigDecimal = BsonInput.bigDecimalFromBytes(br.readBinaryData().getData) + override def readBigInt(): BigInt = BigInt(br.readBinaryData().getData) + override def readBigDecimal(): BigDecimal = BsonInput.bigDecimalFromBytes(br.readBinaryData().getData) override def readBinary(): Array[Byte] = br.readBinaryData().getData override def readList(): BsonReaderListInput = { br.readStartArray() diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala index d567b3ad0..c9f59531f 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala @@ -33,8 +33,8 @@ final class BsonValueOutput(receiver: BsonValue => Unit = _ => ()) extends BsonO override def writeLong(long: Long): Unit = setValue(new BsonInt64(long)) override def writeTimestamp(millis: Long): Unit = setValue(new BsonDateTime(millis)) override def writeDouble(double: Double): Unit = setValue(new BsonDouble(double)) - override def writeBigInteger(bigInteger: JBigInteger): Unit = setValue(new BsonBinary(bigInteger.toByteArray)) - override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = setValue(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) + override def writeBigInt(bigInt: BigInt): Unit = setValue(new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = setValue(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = setValue(new BsonBinary(binary)) override def writeList(): ListOutput = new BsonValueListOutput(setValue) override def writeObject(): ObjectOutput = new BsonValueObjectOutput(setValue) diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala index d77aeb28c..5490c9e05 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala @@ -13,9 +13,9 @@ final class BsonWriterOutput(bw: BsonWriter) extends BsonOutput { override def writeLong(long: Long): Unit = bw.writeInt64(long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(millis) override def writeDouble(double: Double): Unit = bw.writeDouble(double) - override def writeBigInteger(bigInteger: JBigInteger): Unit = - bw.writeBinaryData(new BsonBinary(bigInteger.toByteArray)) - override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = + override def writeBigInt(bigInt: BigInt): Unit = + bw.writeBinaryData(new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = bw.writeBinaryData(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { @@ -37,9 +37,9 @@ final class BsonWriterNamedOutput(escapedName: String, bw: BsonWriter) extends B override def writeLong(long: Long): Unit = bw.writeInt64(escapedName, long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(escapedName, millis) override def writeDouble(double: Double): Unit = bw.writeDouble(escapedName, double) - override def writeBigInteger(bigInteger: JBigInteger): Unit = - bw.writeBinaryData(escapedName, new BsonBinary(bigInteger.toByteArray)) - override def writeBigDecimal(bigDecimal: JBigDecimal): Unit = + override def writeBigInt(bigInt: BigInt): Unit = + bw.writeBinaryData(escapedName, new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = bw.writeBinaryData(escapedName, new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(escapedName, new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { diff --git a/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala index eee242869..a77014a1e 100644 --- a/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala +++ b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala @@ -1,13 +1,12 @@ package com.avsystem.commons package mongo -import org.scalacheck.Arbitrary import org.scalatest.FunSuite import org.scalatest.prop.PropertyChecks class BigDecimalEncodingTest extends FunSuite with PropertyChecks { test("BigDecimal BSON encoding") { - forAll(Arbitrary.arbitrary[BigDecimal].map(_.bigDecimal)) { value => + forAll { value: BigDecimal => assert(value == BsonInput.bigDecimalFromBytes(BsonOutput.bigDecimalBytes(value))) } } From 2b6c6c8be2ed9b59af8560297d78359ef4914c05 Mon Sep 17 00:00:00 2001 From: ghik Date: Wed, 13 Jun 2018 15:23:09 +0200 Subject: [PATCH 3/4] reformat of ifs --- .../serialization/StreamInputOutput.scala | 123 +++++++++--------- 1 file changed, 60 insertions(+), 63 deletions(-) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala index 80ad3059e..f92bbbd0f 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala @@ -49,71 +49,68 @@ class StreamInput(is: DataInputStream) extends Input { InputType.Simple } - def readNull(): Null = if (markerByte == NullMarker) - null - else - throw new ReadFailure(s"Expected null but $markerByte found") - - def readString(): String = if (markerByte == StringMarker) - is.readUTF() - else - throw new ReadFailure(s"Expected string but $markerByte found") - - def readBoolean(): Boolean = if (markerByte == BooleanMarker) - is.readBoolean() - else - throw new ReadFailure(s"Expected boolean but $markerByte found") - - def readInt(): Int = if (markerByte == IntMarker) - is.readInt() - else - throw new ReadFailure(s"Expected int but $markerByte found") - - def readLong(): Long = if (markerByte == LongMarker) - is.readLong() - else - throw new ReadFailure(s"Expected long but $markerByte found") - - def readDouble(): Double = if (markerByte == DoubleMarker) - is.readDouble() - else - throw new ReadFailure(s"Expected double but $markerByte found") - - def readBigInt(): BigInt = if (markerByte == BitIntMarker) { - val len = is.readInt() - val bytes = new Array[Byte](len) - is.read(bytes) - BigInt(bytes) - } else - throw new ReadFailure(s"Expected big integer but $markerByte found") - - def readBigDecimal(): BigDecimal = if (markerByte == BigDecimalMarker) { - val len = is.readInt() - val bytes = new Array[Byte](len) - is.read(bytes) - val unscaled = BigInt(bytes) - val scale = is.readInt() - BigDecimal(unscaled, scale) - } else - throw new ReadFailure(s"Expected big decimal but $markerByte found") - - def readBinary(): Array[Byte] = if (markerByte == ByteArrayMarker) { - val binary = Array.ofDim[Byte](is.readInt()) - is.readFully(binary) - binary - } else { - throw new ReadFailure(s"Expected binary array but $markerByte found") - } + def readNull(): Null = + if (markerByte == NullMarker) null + else throw new ReadFailure(s"Expected null but $markerByte found") + + def readString(): String = + if (markerByte == StringMarker) is.readUTF() + else throw new ReadFailure(s"Expected string but $markerByte found") + + def readBoolean(): Boolean = + if (markerByte == BooleanMarker) is.readBoolean() + else throw new ReadFailure(s"Expected boolean but $markerByte found") + + def readInt(): Int = + if (markerByte == IntMarker) is.readInt() + else throw new ReadFailure(s"Expected int but $markerByte found") + + def readLong(): Long = + if (markerByte == LongMarker) is.readLong() + else throw new ReadFailure(s"Expected long but $markerByte found") + + def readDouble(): Double = + if (markerByte == DoubleMarker) is.readDouble() + else throw new ReadFailure(s"Expected double but $markerByte found") + + def readBigInt(): BigInt = + if (markerByte == BitIntMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + BigInt(bytes) + } else { + throw new ReadFailure(s"Expected big integer but $markerByte found") + } + + def readBigDecimal(): BigDecimal = + if (markerByte == BigDecimalMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + val unscaled = BigInt(bytes) + val scale = is.readInt() + BigDecimal(unscaled, scale) + } else { + throw new ReadFailure(s"Expected big decimal but $markerByte found") + } + + def readBinary(): Array[Byte] = + if (markerByte == ByteArrayMarker) { + val binary = Array.ofDim[Byte](is.readInt()) + is.readFully(binary) + binary + } else { + throw new ReadFailure(s"Expected binary array but $markerByte found") + } - def readList(): ListInput = if (markerByte == ListStartMarker) - new StreamListInput(is) - else - throw new ReadFailure(s"Expected list but $markerByte found") + def readList(): ListInput = + if (markerByte == ListStartMarker) new StreamListInput(is) + else throw new ReadFailure(s"Expected list but $markerByte found") - def readObject(): ObjectInput = if (markerByte == ObjectStartMarker) - new StreamObjectInput(is) - else - throw new ReadFailure(s"Expected object but $markerByte found") + def readObject(): ObjectInput = + if (markerByte == ObjectStartMarker) new StreamObjectInput(is) + else throw new ReadFailure(s"Expected object but $markerByte found") def skip(): Unit = markerByte match { case NullMarker => From 26046bb722a250f3c07103f303b68d7079c0d489 Mon Sep 17 00:00:00 2001 From: ghik Date: Wed, 13 Jun 2018 15:28:48 +0200 Subject: [PATCH 4/4] using Integer.BYTES instead of 4 --- .../avsystem/commons/serialization/StreamInputOutput.scala | 2 +- .../scala/com/avsystem/commons/mongo/BsonInputOutput.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala index f92bbbd0f..86a2f5519 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala @@ -141,7 +141,7 @@ class StreamInput(is: DataInputStream) extends Input { case BitIntMarker => is.skipBytes(is.readInt()) case BigDecimalMarker => - is.skipBytes(is.readInt() + 4) + is.skipBytes(is.readInt() + Integer.BYTES) case unexpected => throw new ReadFailure(s"Unexpected marker byte: $unexpected") } diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala index 362c0aef3..ed9c41768 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala @@ -13,7 +13,7 @@ trait BsonInput extends Any with Input { object BsonInput { def bigDecimalFromBytes(bytes: Array[Byte]): BigDecimal = { val buf = ByteBuffer.wrap(bytes) - val unscaledBytes = new Array[Byte](bytes.length - 4) + val unscaledBytes = new Array[Byte](bytes.length - Integer.BYTES) buf.get(unscaledBytes) val unscaled = BigInt(unscaledBytes) val scale = buf.getInt @@ -28,6 +28,6 @@ trait BsonOutput extends Any with Output { object BsonOutput { def bigDecimalBytes(bigDecimal: BigDecimal): Array[Byte] = { val unscaledBytes = bigDecimal.bigDecimal.unscaledValue.toByteArray - ByteBuffer.allocate(unscaledBytes.length + 4).put(unscaledBytes).putInt(bigDecimal.scale).array + ByteBuffer.allocate(unscaledBytes.length + Integer.BYTES).put(unscaledBytes).putInt(bigDecimal.scale).array } }