Commit aa19c69

rtreffer authored and liancheng committed
[SPARK-4176] [SQL] Supports decimal types with precision > 18 in Parquet
This PR is based on #6796 authored by rtreffer.

To support large decimal precisions (> 18), we do the following things in this PR:

1. Making `CatalystSchemaConverter` support large decimal precision

   Decimal types with large precision are always converted to fixed-length byte array.

2. Making `CatalystRowConverter` support reading decimal values with large precision

   When the precision is > 18, constructs `Decimal` values with an unscaled `BigInteger` rather than an unscaled `Long`.

3. Making `RowWriteSupport` support writing decimal values with large precision

   In this PR we always write decimals as fixed-length byte array, because the Parquet write path hasn't been refactored to conform to the Parquet format spec (see SPARK-6774 & SPARK-8848).

Two follow-up tasks should be done in future PRs:

- [ ] Writing decimals as `INT32`, `INT64` when possible while fixing SPARK-8848
- [ ] Adding compatibility tests as part of SPARK-5463

Author: Cheng Lian <lian@databricks.com>

Closes #7455 from liancheng/spark-4176 and squashes the following commits:

a543d10 [Cheng Lian] Fixes errors introduced while rebasing
9e31cdf [Cheng Lian] Supports decimals with precision > 18 for Parquet
1 parent 6228381 commit aa19c69
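
As a quick illustration of what this change enables, here is a minimal round trip with a wide decimal column (a sketch against the 1.5-era `SQLContext`/`DataFrame` API; the output path and column name are made up for the example):

```scala
import org.apache.spark.sql.types.DecimalType
import sqlContext.implicits._

// decimal(38, 18) exceeds the old precision-18 limit; it is now written to
// Parquet as a FIXED_LEN_BYTE_ARRAY annotated with DECIMAL and read back intact.
val wide = sqlContext
  .range(0, 10)
  .select($"id".cast(DecimalType(38, 18)).as("dec"))

wide.write.parquet("/tmp/wide-decimals")
sqlContext.read.parquet("/tmp/wide-decimals").show()
```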

File tree

4 files changed: +85 −60 lines changed

sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala


sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala

Lines changed: 16 additions & 9 deletions

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.{BigDecimal, BigInteger}
 import java.nio.ByteOrder
 
 import scala.collection.JavaConversions._
@@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter(
       val scale = decimalType.scale
       val bytes = value.getBytes
 
-      var unscaled = 0L
-      var i = 0
+      if (precision <= 8) {
+        // Constructs a `Decimal` with an unscaled `Long` value if possible.
+        var unscaled = 0L
+        var i = 0
 
-      while (i < bytes.length) {
-        unscaled = (unscaled << 8) | (bytes(i) & 0xff)
-        i += 1
-      }
+        while (i < bytes.length) {
+          unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+          i += 1
+        }
 
-      val bits = 8 * bytes.length
-      unscaled = (unscaled << (64 - bits)) >> (64 - bits)
-      Decimal(unscaled, precision, scale)
+        val bits = 8 * bytes.length
+        unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+        Decimal(unscaled, precision, scale)
+      } else {
+        // Otherwise, resorts to an unscaled `BigInteger` instead.
+        Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale)
+      }
     }
   }
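
When the precision is small enough (at most 8 in this version of the converter), the code above decodes the fixed-length, big-endian two's-complement bytes straight into a `Long`; the shift pair at the end sign-extends the packed value so negative decimals come out right. A standalone sketch of the same trick (the helper name is invented for illustration):

```scala
// Pack big-endian two's-complement bytes into the low bits of a Long, then
// shift left and arithmetically right so the value's own sign bit fills the
// unused high bits.
def signedUnscaledLong(bytes: Array[Byte]): Long = {
  var unscaled = 0L
  for (b <- bytes) {
    unscaled = (unscaled << 8) | (b & 0xff)
  }
  val bits = 8 * bytes.length
  (unscaled << (64 - bits)) >> (64 - bits)
}

// A single 0xff byte decodes to -1 rather than 255.
assert(signedUnscaledLong(Array(0xff.toByte)) == -1L)
```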

sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala

Lines changed: 25 additions & 21 deletions

@@ -387,24 +387,18 @@ private[parquet] class CatalystSchemaConverter(
       // =====================================
 
       // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
-      // always store decimals in fixed-length byte arrays.
-      case DecimalType.Fixed(precision, scale)
-        if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+      // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+      // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+      // by `DECIMAL`.
+      case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec =>
        Types
          .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
          .as(DECIMAL)
          .precision(precision)
          .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
          .named(field.name)
 
-      case dec @ DecimalType() if !followParquetFormatSpec =>
-        throw new AnalysisException(
-          s"Data type $dec is not supported. " +
-            s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
-            "decimal precision and scale must be specified, " +
-            "and precision must be less than or equal to 18.")
-
       // =====================================
       // Decimals (follow Parquet format spec)
       // =====================================
@@ -436,7 +430,7 @@
          .as(DECIMAL)
          .precision(precision)
          .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
          .named(field.name)
 
       // ===================================================
@@ -548,15 +542,6 @@
         Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
       .asInstanceOf[Int]
   }
-
-  // Min byte counts needed to store decimals with various precisions
-  private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision =>
-    var numBytes = 1
-    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
-      numBytes += 1
-    }
-    numBytes
-  }
 }
 
@@ -580,4 +565,23 @@ private[parquet] object CatalystSchemaConverter {
      throw new AnalysisException(message)
    }
  }
+
+  private def computeMinBytesForPrecision(precision : Int) : Int = {
+    var numBytes = 1
+    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+      numBytes += 1
+    }
+    numBytes
+  }
+
+  private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+  def minBytesForPrecision(precision : Int) : Int = {
+    if (precision < MIN_BYTES_FOR_PRECISION.length) {
+      MIN_BYTES_FOR_PRECISION(precision)
+    } else {
+      computeMinBytesForPrecision(precision)
+    }
+  }
 }
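
The new `minBytesForPrecision` helper answers one question: how many bytes does a fixed-length binary need so that a signed two's-complement integer of that width can hold any unscaled value with the given number of decimal digits, i.e. the smallest n with 2^(8n - 1) - 1 >= 10^precision - 1. The committed code approximates this with `math.pow`; the exact-arithmetic check below (with an invented name) just confirms the sizing rule at a few interesting precisions:

```scala
// Smallest byte count n such that a signed (8 * n)-bit integer covers every
// precision-digit unscaled value (exact BigInt version of the same rule).
def bytesFor(precision: Int): Int =
  Iterator.from(1).find(n => BigInt(2).pow(8 * n - 1) - 1 >= BigInt(10).pow(precision) - 1).get

assert(bytesFor(18) == 8)   // precision 18 still fits in an unscaled Long
assert(bytesFor(19) == 9)   // precision 19 is the first to need more than 8 bytes
assert(bytesFor(38) == 16)  // the maximum Spark SQL precision takes 16 bytes
```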

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 43 additions & 21 deletions

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.BigInteger
 import java.nio.{ByteBuffer, ByteOrder}
 import java.util.{HashMap => JHashMap}
 
@@ -114,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
          Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
        case BinaryType => writer.addBinary(
          Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
-        case d: DecimalType =>
-          if (d.precision > 18) {
-            sys.error(s"Unsupported datatype $d, cannot write to consumer")
-          }
-          writeDecimal(value.asInstanceOf[Decimal], d.precision)
+        case DecimalType.Fixed(precision, _) =>
+          writeDecimal(value.asInstanceOf[Decimal], precision)
        case _ => sys.error(s"Do not know how to writer $schema to consumer")
      }
    }
@@ -199,20 +197,47 @@
     writer.endGroup()
   }
 
-  // Scratch array used to write decimals as fixed-length binary
-  private[this] val scratchBytes = new Array[Byte](8)
+  // Scratch array used to write decimals as fixed-length byte array
+  private[this] var reusableDecimalBytes = new Array[Byte](16)
 
   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
-    val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
-    val unscaledLong = decimal.toUnscaledLong
-    var i = 0
-    var shift = 8 * (numBytes - 1)
-    while (i < numBytes) {
-      scratchBytes(i) = (unscaledLong >> shift).toByte
-      i += 1
-      shift -= 8
+    val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
+
+    def longToBinary(unscaled: Long): Binary = {
+      var i = 0
+      var shift = 8 * (numBytes - 1)
+      while (i < numBytes) {
+        reusableDecimalBytes(i) = (unscaled >> shift).toByte
+        i += 1
+        shift -= 8
+      }
+      Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
     }
-    writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+
+    def bigIntegerToBinary(unscaled: BigInteger): Binary = {
+      unscaled.toByteArray match {
+        case bytes if bytes.length == numBytes =>
+          Binary.fromByteArray(bytes)
+
+        case bytes if bytes.length <= reusableDecimalBytes.length =>
+          val signedByte = (if (bytes.head < 0) -1 else 0).toByte
+          java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+          System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
+          Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+
+        case bytes =>
+          reusableDecimalBytes = new Array[Byte](bytes.length)
+          bigIntegerToBinary(unscaled)
+      }
+    }
+
+    val binary = if (numBytes <= 8) {
+      longToBinary(decimal.toUnscaledLong)
+    } else {
+      bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue())
+    }
+
+    writer.addBinary(binary)
   }
 
   // array used to write Timestamp as Int96 (fixed-length binary)
@@ -268,11 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
        writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
      case BinaryType =>
        writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(record.getDecimal(index), d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(record.getDecimal(index), precision)
      case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
    }
  }
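
For precisions that no longer fit in a `Long`, the write path above takes `BigInteger.toByteArray` (the minimal big-endian two's-complement encoding) and left-pads it with the sign byte up to the fixed length declared in the Parquet schema. A hypothetical standalone version of that padding step, without the reusable scratch buffer:

```scala
import java.math.BigInteger

// Left-pad the minimal two's-complement encoding with the sign byte (0x00 for
// non-negative values, 0xff for negative ones) so it fills exactly numBytes.
def toFixedLength(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
  val bytes = unscaled.toByteArray
  require(bytes.length <= numBytes, s"value needs ${bytes.length} bytes but only $numBytes fit")
  val fixed = new Array[Byte](numBytes)
  val signByte: Byte = if (unscaled.signum < 0) -1 else 0
  java.util.Arrays.fill(fixed, 0, numBytes - bytes.length, signByte)
  System.arraycopy(bytes, 0, fixed, numBytes - bytes.length, bytes.length)
  fixed
}

// For example, -1 padded to four bytes is 0xff 0xff 0xff 0xff.
assert(toFixedLength(BigInteger.valueOf(-1), 4).forall(_ == 0xff.toByte))
```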

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala

Lines changed: 1 addition & 9 deletions

@@ -106,21 +106,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
       // Parquet doesn't allow column names with spaces, have to add an alias here
       .select($"_1" cast decimal as "dec")
 
-    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) {
      withTempPath { dir =>
        val data = makeDecimalRDD(DecimalType(precision, scale))
        data.write.parquet(dir.getCanonicalPath)
        checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
      }
    }
-
-    // Decimals with precision above 18 are not yet supported
-    intercept[Throwable] {
-      withTempPath { dir =>
-        makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
-        sqlContext.read.parquet(dir.getCanonicalPath).collect()
-      }
-    }
  }
 
  test("date type") {
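
Besides `checkAnswer`, the effect of the new `(19, 0)` and `(38, 37)` cases is easy to eyeball interactively; reading wide-decimal Parquet data back should preserve the declared precision and scale (a sketch reusing the made-up path from the example near the top of this page):

```scala
// Round-trip check against the data written in the first sketch above: a write
// that would previously have failed with "Unsupported datatype ... cannot write to consumer".
val df = sqlContext.read.parquet("/tmp/wide-decimals")
df.printSchema()   // dec: decimal(38,18) (nullable = true)
df.show(5)
```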
