diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java index 7aa6b079073c4..ea15607bb8d15 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java @@ -21,16 +21,16 @@ public final class ParquetDictionary implements Dictionary { private org.apache.parquet.column.Dictionary dictionary; - private boolean castLongToInt = false; + private boolean needTransform = false; - public ParquetDictionary(org.apache.parquet.column.Dictionary dictionary, boolean castLongToInt) { + public ParquetDictionary(org.apache.parquet.column.Dictionary dictionary, boolean needTransform) { this.dictionary = dictionary; - this.castLongToInt = castLongToInt; + this.needTransform = needTransform; } @Override public int decodeToInt(int id) { - if (castLongToInt) { + if (needTransform) { return (int) dictionary.decodeToLong(id); } else { return dictionary.decodeToInt(id); @@ -39,7 +39,14 @@ public int decodeToInt(int id) { @Override public long decodeToLong(int id) { - return dictionary.decodeToLong(id); + if (needTransform) { + // For unsigned int32, Parquet stores the values as dictionary-encoded signed int32 + // whenever a dictionary is available. + // Here we lazily decode the id to the original signed int value, then widen it to long (uint32). + return Integer.toUnsignedLong(dictionary.decodeToInt(id)); + } else { + return dictionary.decodeToLong(id); + } } @Override
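A minimal standalone sketch, not part of the patch (class name is illustrative), of the zero-extension the lazy decodeToLong path above relies on: a UINT_32 value stored as a signed int32 bit pattern is recovered with Integer.toUnsignedLong rather than a plain cast, which would sign-extend.

public class UnsignedInt32WideningSketch {
  public static void main(String[] args) {
    int stored = (int) 4294967295L;               // UINT_32 max, stored as the int32 bit pattern -1
    long widened = Integer.toUnsignedLong(stored);
    System.out.println(widened);                  // 4294967295 (zero-extended)
    System.out.println((long) stored);            // -1 (a plain widening cast would sign-extend)
  }
}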
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index dac18b1abe047..e091cbac41f75 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -46,7 +46,6 @@ import org.apache.spark.sql.types.DecimalType; import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator; @@ -279,16 +278,20 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). PrimitiveType primitiveType = descriptor.getPrimitiveType(); - if (primitiveType.getOriginalType() == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() <= Decimal.MAX_INT_DIGITS() && - primitiveType.getPrimitiveTypeName() == INT64) { - // We need to make sure that we initialize the right type for the dictionary otherwise - // WritableColumnVector will throw an exception when trying to decode to an Int when the - // dictionary is in fact initialized as Long - column.setDictionary(new ParquetDictionary(dictionary, true)); - } else { - column.setDictionary(new ParquetDictionary(dictionary, false)); - } + + // We need to make sure that we initialize the right type for the dictionary otherwise + // WritableColumnVector will throw an exception when trying to decode to an Int when the + // dictionary is in fact initialized as Long + boolean castLongToInt = primitiveType.getOriginalType() == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() <= Decimal.MAX_INT_DIGITS() && + primitiveType.getPrimitiveTypeName() == INT64; + + // We require a long value, but we need to use the dictionary to decode the original + // signed int first. + boolean isUnsignedInt32 = primitiveType.getOriginalType() == OriginalType.UINT_32; + + column.setDictionary( + new ParquetDictionary(dictionary, castLongToInt || isUnsignedInt32)); } else { decodeDictionaryIds(rowId, num, column, dictionaryIds); } @@ -370,6 +373,18 @@ private void decodeDictionaryIds( column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i))); } } + } else if (column.dataType() == DataTypes.LongType) { + // In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType. + // For unsigned int32, Parquet stores the values as dictionary-encoded signed int32 + // whenever a dictionary is available. + // Here we eagerly decode each id to the original signed int value, then widen it to + // long (uint32). + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + column.putLong(i, + Integer.toUnsignedLong(dictionary.decodeToInt(dictionaryIds.getDictId(i)))); + } + } } else if (column.dataType() == DataTypes.ByteType) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { @@ -565,6 +580,12 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) throw canReadAsIntDecimal(column.dataType())) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.LongType) { + // In `ParquetToSparkSchemaConverter`, we map parquet UINT32 to our LongType. + // For unsigned int32, Parquet stores the values as plain signed int32 when the dictionary encoding falls back. + // We read them as long values.
+ defColumn.readUnsignedIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 994779b618829..99beb0250a62b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -83,6 +83,15 @@ public final void readIntegers(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void readUnsignedIntegers(int total, WritableColumnVector c, int rowId) { + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, Integer.toUnsignedLong(buffer.getInt())); + } + } + // A fork of `readIntegers` to rebase the date values. For performance reasons, this method // iterates the values twice: check if we need to rebase first, then go to the optimized branch // if rebase is not needed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index a6c8292671d3f..384bcb30a17c7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -203,6 +203,41 @@ public void readIntegers( } } + // A fork of `readIntegers`, reading the signed integers as unsigned in long type + public void readUnsignedIntegers( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) throws IOException { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readUnsignedIntegers(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, Integer.toUnsignedLong(data.readInteger())); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + // A fork of `readIntegers`, which rebases the date int value (days) before filling // the Spark column vector. 
public void readIntegersWithRebase( @@ -602,6 +637,11 @@ public void readIntegers(int total, WritableColumnVector c, int rowId) { } } + @Override + public void readUnsignedIntegers(int total, WritableColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + @Override public void readIntegersWithRebase( int total, WritableColumnVector c, int rowId, boolean failIfRebase) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 35db8f235ed60..9f5d944329343 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -41,6 +41,7 @@ public interface VectorizedValuesReader { void readBytes(int total, WritableColumnVector c, int rowId); void readIntegers(int total, WritableColumnVector c, int rowId); void readIntegersWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase); + void readUnsignedIntegers(int total, WritableColumnVector c, int rowId); void readLongs(int total, WritableColumnVector c, int rowId); void readLongsWithRebase(int total, WritableColumnVector c, int rowId, boolean failIfRebase); void readFloats(int total, WritableColumnVector c, int rowId); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index dca12ff6b4deb..2c610ec539ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -252,6 +252,11 @@ private[parquet] class ParquetRowConverter( updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { + case LongType if parquetType.getOriginalType == OriginalType.UINT_32 => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + updater.setLong(Integer.toUnsignedLong(value)) + } case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new ParquetPrimitiveConverter(updater) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 1bf5891c3567e..ef094bdca0efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -131,13 +131,11 @@ class ParquetToSparkSchemaConverter( case INT32 => originalType match { case INT_8 => ByteType - case INT_16 => ShortType - case INT_32 | null => IntegerType + case INT_16 | UINT_8 => ShortType + case INT_32 | UINT_16 | null => IntegerType case DATE => DateType case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) - case UINT_8 => typeNotSupported() - case UINT_16 => typeNotSupported() - case UINT_32 => typeNotSupported() + case UINT_32 => LongType case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index fbe651502296c..82e605fc9fd14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -24,17 +24,16 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.column.{Encoding, ParquetProperties} -import org.apache.parquet.example.data.{Group, GroupWriter} -import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0 +import org.apache.parquet.example.data.Group +import org.apache.parquet.example.data.simple.{SimpleGroup, SimpleGroupFactory} import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.api.WriteSupport -import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.example.ExampleParquetWriter import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.hadoop.metadata.CompressionCodecName.GZIP import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} @@ -49,26 +48,6 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport -// with an empty configuration (it is after all not intended to be used in this way?) -// and members are private so we need to make our own in order to pass the schema -// to the writer. -private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { - var groupWriter: GroupWriter = null - - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - groupWriter = new GroupWriter(recordConsumer, schema) - } - - override def init(configuration: Configuration): WriteContext = { - new WriteContext(schema, new java.util.HashMap[String, String]()) - } - - override def write(record: Group): Unit = { - groupWriter.write(record) - } -} - /** * A test suite that tests basic Parquet I/O. 
*/ @@ -310,21 +289,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } - test("SPARK-10113 Support for unsigned Parquet logical types") { + test("SPARK-34817: Support for unsigned Parquet logical types") { val parquetSchema = MessageTypeParser.parseMessageType( """message root { - | required int32 c(UINT_32); + | required INT32 a(UINT_8); + | required INT32 b(UINT_16); + | required INT32 c(UINT_32); |} """.stripMargin) + val expectedSparkTypes = Seq(ShortType, IntegerType, LongType) + withTempPath { location => val path = new Path(location.getCanonicalPath) val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) - val errorMessage = intercept[Throwable] { - spark.read.parquet(path.toString).printSchema() - }.toString - assert(errorMessage.contains("Parquet type not supported")) + val sparkTypes = spark.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) } } @@ -381,9 +362,27 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkCompressionCodec(CompressionCodecName.SNAPPY) } + private def createParquetWriter( + schema: MessageType, + path: Path, + dictionaryEnabled: Boolean = false): ParquetWriter[Group] = { + val hadoopConf = spark.sessionState.newHadoopConf() + + ExampleParquetWriter + .builder(path) + .withDictionaryEncoding(dictionaryEnabled) + .withType(schema) + .withWriterVersion(PARQUET_1_0) + .withCompressionCodec(GZIP) + .withRowGroupSize(1024 * 1024) + .withPageSize(1024) + .withConf(hadoopConf) + .build() + } + test("read raw Parquet file") { def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( + val schemaStr = """ |message root { | required boolean _1; @@ -392,22 +391,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession | required float _4; | required double _5; |} - """.stripMargin) - - val testWriteSupport = new TestGroupWriteSupport(schema) - /** - * Provide a builder for constructing a parquet writer - after PARQUET-248 directly - * constructing the writer is deprecated and should be done through a builder. The default - * builders include Avro - but for raw Parquet writing we must create our own builder. 
- */ - class ParquetWriterBuilder() extends - ParquetWriter.Builder[Group, ParquetWriterBuilder](path) { - override def getWriteSupport(conf: Configuration) = testWriteSupport - - override def self() = this - } + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) - val writer = new ParquetWriterBuilder().build() + + val writer = createParquetWriter(schema, path) (0 until 10).foreach { i => val record = new SimpleGroup(schema) @@ -432,6 +420,45 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + test("SPARK-34817: Read UINT_8/UINT_16/UINT_32 from parquet") { + Seq(true, false).foreach { dictionaryEnabled => + def makeRawParquetFile(path: Path): Unit = { + val schemaStr = + """message root { + | required INT32 a(UINT_8); + | required INT32 b(UINT_16); + | required INT32 c(UINT_32); + |} + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) + + val writer = createParquetWriter(schema, path, dictionaryEnabled) + + val factory = new SimpleGroupFactory(schema) + (0 until 1000).foreach { i => + val group = factory.newGroup() + .append("a", i % 100 + Byte.MaxValue) + .append("b", i % 100 + Short.MaxValue) + .append("c", i % 100 + Int.MaxValue) + writer.write(group) + } + writer.close() + } + + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + readParquetFile(path.toString) { df => + checkAnswer(df, (0 until 1000).map { i => + Row(i % 100 + Byte.MaxValue, + i % 100 + Short.MaxValue, + i % 100 + Int.MaxValue.toLong) + }) + } + } + } + } + test("write metadata") { val hadoopConf = spark.sessionState.newHadoopConf() withTempPath { file =>
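For reference, the ParquetToSparkSchemaConverter change above widens each unsigned type to the next larger signed Spark type so its full range fits: UINT_8 becomes ShortType, UINT_16 becomes IntegerType, and UINT_32 becomes LongType. Only UINT_32 needs the Integer.toUnsignedLong treatment in the readers, because UINT_8 and UINT_16 values can never set the sign bit of the underlying int32. A standalone sketch, not part of the patch (class name is illustrative):

public class UnsignedRangeSketch {
  public static void main(String[] args) {
    // UINT_8 (0..255) and UINT_16 (0..65535) always come back as non-negative int32 values,
    // so they only need a Spark type wide enough to hold them (ShortType / IntegerType).
    System.out.println(255 <= Short.MAX_VALUE);           // true
    System.out.println(65535 <= Integer.MAX_VALUE);       // true
    // UINT_32 (0..4294967295) can use the sign bit, so it must be zero-extended into LongType.
    System.out.println(4294967295L <= Long.MAX_VALUE);    // true
    System.out.println(4294967295L > Integer.MAX_VALUE);  // true, which is why IntegerType cannot hold it
  }
}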
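A hypothetical end-to-end read, assuming a local file laid out like the one makeRawParquetFile produces in the new test (path, session settings and class name are illustrative): after this change such a file loads instead of failing with "Parquet type not supported", and the three columns surface as short, integer and long.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ReadUnsignedParquetSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("read-unsigned-parquet-sketch")
        .master("local[*]")
        .getOrCreate();
    // Illustrative path; point it at a Parquet file with UINT_8/UINT_16/UINT_32 columns.
    Dataset<Row> df = spark.read().parquet("/tmp/part-r-0.parquet");
    df.printSchema();      // expect a: short, b: integer, c: long under the new mapping
    df.show(5, false);
    spark.stop();
  }
}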