From 9e31cdf47ca4a0802c4e0d5ad3fd473a7cb8d1bc Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Fri, 17 Jul 2015 07:33:47 +0800
Subject: [PATCH 1/2] Supports decimals with precision > 18 for Parquet

---
 .../sql/parquet/CatalystRowConverter.scala    | 25 +++---
 .../sql/parquet/CatalystSchemaConverter.scala | 50 +++++++-----
 .../sql/parquet/ParquetTableSupport.scala     | 79 +++++++++++++++----
 .../spark/sql/parquet/ParquetIOSuite.scala    |  6 +-
 4 files changed, 113 insertions(+), 47 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index b5e4263008f5..e00bd90edb3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.{BigDecimal, BigInteger}
 import java.nio.ByteOrder
 
 import scala.collection.JavaConversions._
@@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter(
     val scale = decimalType.scale
     val bytes = value.getBytes
 
-    var unscaled = 0L
-    var i = 0
+    if (precision <= 8) {
+      // Constructs a `Decimal` with an unscaled `Long` value if possible.
+      var unscaled = 0L
+      var i = 0
 
-    while (i < bytes.length) {
-      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
-      i += 1
-    }
+      while (i < bytes.length) {
+        unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+        i += 1
+      }
 
-    val bits = 8 * bytes.length
-    unscaled = (unscaled << (64 - bits)) >> (64 - bits)
-    Decimal(unscaled, precision, scale)
+      val bits = 8 * bytes.length
+      unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+      Decimal(unscaled, precision, scale)
+    } else {
+      // Otherwise, resorts to an unscaled `BigInteger` instead.
+      Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale)
+    }
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index e9ef01e2dba1..b116e04081d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -387,23 +387,21 @@ private[parquet] class CatalystSchemaConverter(
     // =====================================
 
     // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
-    // always store decimals in fixed-length byte arrays.
-    case DecimalType.Fixed(precision, scale)
-      if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+    // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+    // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+    // by `DECIMAL`.
+    case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec =>
       Types
         .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
         .as(DECIMAL)
         .precision(precision)
         .scale(scale)
-        .length(minBytesForPrecision(precision))
+        .length(CatalystSchemaConverter.minBytesForPrecision(precision))
         .named(field.name)
 
-    case dec @ DecimalType() if !followParquetFormatSpec =>
+    case dec @ DecimalType.Unlimited if !followParquetFormatSpec =>
       throw new AnalysisException(
-        s"Data type $dec is not supported. " +
-          s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
-          "decimal precision and scale must be specified, " +
-          "and precision must be less than or equal to 18.")
+        s"Data type $dec is not supported. Decimal precision must be specified.")
 
     // =====================================
     // Decimals (follow Parquet format spec)
@@ -436,9 +434,13 @@ private[parquet] class CatalystSchemaConverter(
         .as(DECIMAL)
         .precision(precision)
         .scale(scale)
-        .length(minBytesForPrecision(precision))
+        .length(CatalystSchemaConverter.minBytesForPrecision(precision))
         .named(field.name)
 
+    case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
+      throw new AnalysisException(
+        s"Data type $dec is not supported. Decimal precision must be specified.")
+
     // ===================================================
     // ArrayType and MapType (for Spark versions <= 1.4.x)
     // ===================================================
@@ -548,15 +550,6 @@ private[parquet] class CatalystSchemaConverter(
         Math.pow(2, 8 * numBytes - 1) - 1)))  // max value stored in numBytes
       .asInstanceOf[Int]
   }
-
-  // Min byte counts needed to store decimals with various precisions
-  private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision =>
-    var numBytes = 1
-    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
-      numBytes += 1
-    }
-    numBytes
-  }
 }
 
 
@@ -580,4 +573,23 @@ private[parquet] object CatalystSchemaConverter {
       throw new AnalysisException(message)
     }
   }
+
+  private def computeMinBytesForPrecision(precision : Int) : Int = {
+    var numBytes = 1
+    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+      numBytes += 1
+    }
+    numBytes
+  }
+
+  private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+  def minBytesForPrecision(precision : Int) : Int = {
+    if (precision < MIN_BYTES_FOR_PRECISION.length) {
+      MIN_BYTES_FOR_PRECISION(precision)
+    } else {
+      computeMinBytesForPrecision(precision)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index fc9f61a63676..dd16f3a60008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.BigInteger
 import java.nio.{ByteBuffer, ByteOrder}
 import java.util.{HashMap => JHashMap}
 
@@ -114,11 +115,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
         Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
       case BinaryType => writer.addBinary(
         Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(value.asInstanceOf[Decimal], d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(value.asInstanceOf[Decimal], precision)
+      case d @ DecimalType.Unlimited =>
+        sys.error(s"Unsupported data type $d, cannot write to consumer")
       case _ => sys.error(s"Do not know how to writer $schema to consumer")
     }
   }
@@ -200,21 +200,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
   }
 
   // Scratch array used to write decimals as fixed-length binary
-  private[this] val scratchBytes = new Array[Byte](8)
+  private[this] var reusableDecimalBytes = new Array[Byte](16)
 
   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
-    val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
-    val unscaledLong = decimal.toUnscaledLong
-    var i = 0
-    var shift = 8 * (numBytes - 1)
-    while (i < numBytes) {
-      scratchBytes(i) = (unscaledLong >> shift).toByte
-      i += 1
-      shift -= 8
+    val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
+
+    def longToBinary(unscaled: Long): Binary = {
+      var i = 0
+      var shift = 8 * (numBytes - 1)
+      while (i < numBytes) {
+        reusableDecimalBytes(i) = (unscaled >> shift).toByte
+        i += 1
+        shift -= 8
+      }
+      Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+    }
+
+    def bigIntegerToBinary(unscaled: BigInteger): Binary = {
+      unscaled.toByteArray match {
+        case bytes if bytes.length == numBytes =>
+          Binary.fromByteArray(bytes)
+
+        case bytes if bytes.length <= reusableDecimalBytes.length =>
+          val signedByte = (if (bytes.head < 0) -1 else 0).toByte
+          util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+          System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
+          Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+
+        case bytes =>
+          reusableDecimalBytes = new Array[Byte](bytes.length)
+          bigIntegerToBinary(unscaled)
+      }
+    }
+
+    val binary = if (numBytes <= 8) {
+      longToBinary(decimal.toUnscaledLong)
+    } else {
+      bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue())
     }
-    writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
-  }
+
+    writer.addBinary(binary)
+  }
 
   // array used to write Timestamp as Int96 (fixed-length binary)
   private[this] val int96buf = new Array[Byte](12)
@@ -264,6 +290,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
       case TimestampType => writeTimestamp(record.getLong(index))
       case FloatType => writer.addFloat(record.getFloat(index))
       case DoubleType => writer.addDouble(record.getDouble(index))
+<<<<<<< HEAD
       case StringType =>
         writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
       case BinaryType =>
@@ -273,6 +300,26 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
         sys.error(s"Unsupported datatype $d, cannot write to consumer")
       }
       writeDecimal(record.getDecimal(index), d.precision)
+||||||| merged common ancestors
+      case StringType => writer.addBinary(
+        Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
+      case BinaryType => writer.addBinary(
+        Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
+      case d: DecimalType =>
+        if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+          sys.error(s"Unsupported datatype $d, cannot write to consumer")
+        }
+        writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
+=======
+      case StringType => writer.addBinary(
+        Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
+      case BinaryType => writer.addBinary(
+        Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(record(index).asInstanceOf[Decimal], precision)
+      case d @ DecimalType.Unlimited =>
+        sys.error(s"Unsupported data type $d, cannot write to consumer")
+>>>>>>> Supports decimals with precision > 18 for Parquet
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index b5314a3dd92e..c740dc98825d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -106,7 +106,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
       // Parquet doesn't allow column names with spaces, have to add an alias here
       .select($"_1" cast decimal as "dec")
 
-    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) {
       withTempPath { dir =>
         val data = makeDecimalRDD(DecimalType(precision, scale))
         data.write.parquet(dir.getCanonicalPath)
@@ -114,10 +114,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
       }
     }
 
-    // Decimals with precision above 18 are not yet supported
+    // Unlimited-length decimals are not yet supported
     intercept[Throwable] {
       withTempPath { dir =>
-        makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
+        makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
         sqlContext.read.parquet(dir.getCanonicalPath).collect()
       }
     }

From a543d102eeb54b1005cdc06ba74cae05723a14e2 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Fri, 17 Jul 2015 17:24:08 +0800
Subject: [PATCH 2/2] Fixes errors introduced while rebasing

---
 .../sql/parquet/CatalystSchemaConverter.scala |  8 -----
 .../sql/parquet/ParquetTableSupport.scala     | 33 +++----------
 .../spark/sql/parquet/ParquetIOSuite.scala    |  8 -----
 3 files changed, 4 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index b116e04081d4..d43ca95b4eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -399,10 +399,6 @@ private[parquet] class CatalystSchemaConverter(
         .length(CatalystSchemaConverter.minBytesForPrecision(precision))
         .named(field.name)
 
-    case dec @ DecimalType.Unlimited if !followParquetFormatSpec =>
-      throw new AnalysisException(
-        s"Data type $dec is not supported. Decimal precision must be specified.")
-
     // =====================================
     // Decimals (follow Parquet format spec)
     // =====================================
@@ -437,10 +433,6 @@ private[parquet] class CatalystSchemaConverter(
         .length(CatalystSchemaConverter.minBytesForPrecision(precision))
         .named(field.name)
 
-    case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
-      throw new AnalysisException(
-        s"Data type $dec is not supported. Decimal precision must be specified.")
-
     // ===================================================
     // ArrayType and MapType (for Spark versions <= 1.4.x)
     // ===================================================
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index dd16f3a60008..78ecfad1d57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -117,8 +117,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
         Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
       case DecimalType.Fixed(precision, _) =>
         writeDecimal(value.asInstanceOf[Decimal], precision)
-      case d @ DecimalType.Unlimited =>
-        sys.error(s"Unsupported data type $d, cannot write to consumer")
       case _ => sys.error(s"Do not know how to writer $schema to consumer")
     }
   }
@@ -199,7 +197,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
     writer.endGroup()
   }
 
-  // Scratch array used to write decimals as fixed-length binary
+  // Scratch array used to write decimals as fixed-length byte array
   private[this] var reusableDecimalBytes = new Array[Byte](16)
 
   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
@@ -223,7 +221,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
         case bytes if bytes.length <= reusableDecimalBytes.length =>
           val signedByte = (if (bytes.head < 0) -1 else 0).toByte
-          util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+          java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
           System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
           Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
 
@@ -241,6 +239,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
     writer.addBinary(binary)
   }
+
   // array used to write Timestamp as Int96 (fixed-length binary)
   private[this] val int96buf = new Array[Byte](12)
@@ -290,36 +289,12 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
       case TimestampType => writeTimestamp(record.getLong(index))
       case FloatType => writer.addFloat(record.getFloat(index))
       case DoubleType => writer.addDouble(record.getDouble(index))
-<<<<<<< HEAD
       case StringType =>
         writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
       case BinaryType =>
writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case d: DecimalType => - if (d.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(record.getDecimal(index), d.precision) -||||||| merged common ancestors - case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) -======= - case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case DecimalType.Fixed(precision, _) => - writeDecimal(record(index).asInstanceOf[Decimal], precision) - case d @ DecimalType.Unlimited => - sys.error(s"Unsupported data type $d, cannot write to consumer") ->>>>>>> Supports decimals with precision > 18 for Parquet + writeDecimal(record.getDecimal(index), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index c740dc98825d..b415da5b8c13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -113,14 +113,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest { checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } - - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") {
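
The encoding at the heart of both patches is easier to see outside the diff: a decimal is written as its unscaled integer in big-endian two's-complement form, sign-extended to the smallest byte count whose value range covers 10^precision (1 byte up to precision 2, 8 bytes up to 18, 16 bytes at 38). The standalone Scala sketch below condenses the round trip performed by RowWriteSupport.writeDecimal and CatalystRowConverter; the object and helper names (DecimalBinarySketch, encode, decode) are illustrative only, not part of the Spark or Parquet APIs.

import java.math.{BigDecimal => JBigDecimal, BigInteger}

// Illustrative sketch of the fixed-length decimal encoding used by the patches above;
// not Spark or Parquet API.
object DecimalBinarySketch {
  // Smallest two's-complement byte count whose range covers 10^precision,
  // mirroring CatalystSchemaConverter.minBytesForPrecision.
  def minBytesForPrecision(precision: Int): Int = {
    var numBytes = 1
    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
      numBytes += 1
    }
    numBytes
  }

  // Writes an unscaled value as a big-endian, sign-extended, fixed-length array,
  // like the bigIntegerToBinary helper in RowWriteSupport.writeDecimal.
  def encode(unscaled: BigInteger, precision: Int): Array[Byte] = {
    val numBytes = minBytesForPrecision(precision)
    val bytes = unscaled.toByteArray // minimal two's-complement form
    require(bytes.length <= numBytes, s"value does not fit in $numBytes bytes")
    val fixed = new Array[Byte](numBytes)
    val pad: Byte = if (bytes.head < 0) -1 else 0 // sign extension byte
    java.util.Arrays.fill(fixed, 0, numBytes - bytes.length, pad)
    System.arraycopy(bytes, 0, fixed, numBytes - bytes.length, bytes.length)
    fixed
  }

  // Reads the array back, like CatalystRowConverter: a shift-based fast path when
  // the unscaled value fits comfortably in a Long, BigInteger otherwise (which
  // handles sign extension for free).
  def decode(bytes: Array[Byte], precision: Int, scale: Int): JBigDecimal = {
    if (precision <= 8) {
      var unscaled = 0L
      for (b <- bytes) unscaled = (unscaled << 8) | (b & 0xff)
      val bits = 8 * bytes.length
      unscaled = (unscaled << (64 - bits)) >> (64 - bits) // sign extension
      JBigDecimal.valueOf(unscaled, scale)
    } else {
      new JBigDecimal(new BigInteger(bytes), scale)
    }
  }

  def main(args: Array[String]): Unit = {
    // Precision 38 needs 16 bytes; a 38-digit negative value survives the round trip,
    // matching the new (38, 37) test case added to ParquetIOSuite.
    val unscaled = new BigInteger("-12345678901234567890123456789012345678")
    val encoded = encode(unscaled, 38)
    assert(encoded.length == 16)
    assert(decode(encoded, 38, 37) == new JBigDecimal(unscaled, 37))
  }
}

Running DecimalBinarySketch.main checks the round trip for a value that cannot fit in a Long, which is exactly the case that forced the move from the old 8-byte scratchBytes array to the resizable reusableDecimalBytes buffer.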