From b342196a546fa56b37ea9fd3b4b12169b043baaf Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 14 May 2017 19:33:15 -0700 Subject: [PATCH 01/20] [SPARK-20682][SPARK-15474][SPARK-21791] Add new ORCFileFormat based on Apache ORC 1.4.1 --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../execution/datasources/DataSource.scala | 3 +- .../datasources/orc/OrcFileFormat.scala | 154 +++++++- .../datasources/orc/OrcFilters.scala | 178 +++++++++ .../datasources/orc/OrcOptions.scala | 7 + .../datasources/orc/OrcOutputWriter.scala | 59 +++ .../execution/datasources/orc/OrcUtils.scala | 370 ++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +- .../sql/sources/DDLSourceLoadSuite.scala | 7 - ...pache.spark.sql.sources.DataSourceRegister | 1 - .../spark/sql/hive/HiveStrategies.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 320 --------------- .../spark/sql/hive/orc/OrcFileOperator.scala | 20 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 + .../spark/sql/hive/orc/OrcQuerySuite.scala | 17 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 16 +- 16 files changed, 820 insertions(+), 356 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0c5f3f22e31e..6cdfe2fae564 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,6 +1,7 @@ org.apache.spark.sql.execution.datasources.csv.CSVFileFormat org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider org.apache.spark.sql.execution.datasources.json.JsonFileFormat +org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b43d282bd434..fdf113d36b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ @@ -536,7 +537,7 @@ object DataSource extends Logging { val parquet = classOf[ParquetFileFormat].getCanonicalName val csv = 
classOf[CSVFileFormat].getCanonicalName val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" - val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" + val orc = classOf[OrcFileFormat].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 215740e90fe8..9f57efb6b455 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -17,10 +17,29 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.orc.TypeDescription +import java.io._ +import java.net.URI +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc._ +import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA} +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce._ + +import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration private[sql] object OrcFileFormat { private def checkFieldName(name: String): Unit = { @@ -39,3 +58,134 @@ private[sql] object OrcFileFormat { names.foreach(checkFieldName) } } + +class DefaultSource extends OrcFileFormat + +/** + * New ORC File Format based on Apache ORC 1.4.1 and above. 
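Since the service-loader entry and the `DataSource` mapping above route the short name "orc" (and, via the `backwardCompatibilityMap` variable, the old Hive-based class name) to this new implementation, it is reachable through the plain DataFrame API without Hive support. A minimal usage sketch, assuming a local SparkSession and a hypothetical path:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("orc-demo").getOrCreate()
    import spark.implicits._

    // "orc" resolves to org.apache.spark.sql.execution.datasources.orc.OrcFileFormat,
    // registered via the META-INF/services entry above.
    Seq((1, "a"), (2, "b")).toDF("id", "name").write.format("orc").save("/tmp/orc_demo")
    spark.read.format("orc").load("/tmp/orc_demo").show()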
+ */ +class OrcFileFormat + extends FileFormat + with DataSourceRegister + with Serializable { + + override def shortName(): String = "orc" + + override def toString: String = "ORC_1.4" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat] + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcUtils.readSchema(sparkSession, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + + val conf = job.getConfiguration + + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.getSchemaString(dataSchema)) + + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + + conf.asInstanceOf[JobConf] + .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcOptions.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + true + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(dataSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) + } + } + + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val resolver = sparkSession.sessionState.conf.resolver + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val maybeMissingSchema = OrcUtils.getMissingSchema( + resolver, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) + if (maybeMissingSchema.isEmpty) { + Iterator.empty + } else { + val missingSchema = maybeMissingSchema.get + val columns = requiredSchema + .filter(f => missingSchema.getFieldIndex(f.name).isEmpty) + .map(f => dataSchema.fieldIndex(f.name)).mkString(",") + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, columns) + + val fileSplit = + new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val partitionValues = file.partitionValues + + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val mutableRow = new 
SpecificInternalRow(resultSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(resultSchema) + + // Initialize the partition column values once. + for (i <- requiredSchema.length until resultSchema.length) { + val value = partitionValues.get(i - requiredSchema.length, resultSchema(i).dataType) + mutableRow.update(i, value) + } + + val valueWrappers = requiredSchema.fields.map(f => OrcUtils.getValueWrapper(f.dataType)) + iter.map { value => + unsafeProjection(OrcUtils.convertOrcStructToInternalRow(value, dataSchema, requiredSchema, + maybeMissingSchema, Some(valueWrappers), Some(mutableRow))) + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala new file mode 100644 index 000000000000..6a8d3739892b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.serde2.io.HiveDecimalWritable + +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types._ + +/** + * Utility functions to convert Spark data source filters to ORC filters. + */ +private[orc] object OrcFilters { + + /** + * Create ORC filter as a SearchArgument instance. + */ + def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + } yield filter + + for { + conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() + } + + /** + * Return true if this is a searchable type in ORC. + */ + private def isSearchableType(dataType: DataType) = dataType match { + case ByteType | ShortType | FloatType | DoubleType => true + case IntegerType | LongType | StringType | BooleanType => true + case TimestampType | _: DecimalType => true + case _ => false + } + + /** + * Get PredicateLeafType which is corresponding to the given DataType. 
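As a rough sketch of the conversion entry point above (assuming a caller inside the org.apache.spark.sql.execution.datasources.orc package, since OrcFilters is private[orc], and a hypothetical schema): createFilter folds every convertible Spark filter into one conjunctive ORC SearchArgument, which the reader then hands to OrcInputFormat.setSearchArgument when filter pushdown is enabled:

    import org.apache.spark.sql.sources.{GreaterThan, IsNotNull}
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))

    // Both predicates are searchable, so the result is Some(sarg) representing
    // "id IS NOT NULL AND id > 10"; unconvertible filters are simply left out.
    val maybeSarg = OrcFilters.createFilter(schema, Seq(IsNotNull("id"), GreaterThan("id", 10L)))
    assert(maybeSarg.isDefined)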
+ */ + private def getPredicateLeafType(dataType: DataType) = dataType match { + case BooleanType => PredicateLeaf.Type.BOOLEAN + case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG + case FloatType | DoubleType => PredicateLeaf.Type.FLOAT + case StringType => PredicateLeaf.Type.STRING + case DateType => PredicateLeaf.Type.DATE + case TimestampType => PredicateLeaf.Type.TIMESTAMP + case _: DecimalType => PredicateLeaf.Type.DECIMAL + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + } + + /** + * Cast literal values for filters. + * + * We need to cast to long because ORC raises exceptions + * at 'checkLiteralType' of SearchArgumentImpl.java. + */ + private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { + case ByteType | ShortType | IntegerType | LongType => + value.asInstanceOf[Number].longValue + case FloatType | DoubleType => + value.asInstanceOf[Number].doubleValue() + case _: DecimalType => + val decimal = value.asInstanceOf[java.math.BigDecimal] + val decimalWritable = new HiveDecimalWritable(decimal.longValue) + decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) + decimalWritable + case _ => value + } + + /** + * Build a SearchArgument and return the builder so far. + */ + private def buildSearchArgument( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder): Option[Builder] = { + def newBuilder = SearchArgumentFactory.newBuilder() + + def getType(attribute: String): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(attribute)) + + import org.apache.spark.sql.sources._ + + expression match { + case And(left, right) => + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Or(left, right) => + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(dataTypeMap, child, newBuilder) + negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) + } yield negate.end() + + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). 
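To make the AND-safety comment above concrete with a hypothetical example: the binary column below is rejected by isSearchableType, and because both children of the AND are probed against a throw-away builder first, the whole NOT(AND(...)) is dropped rather than degrading into a wrong NOT(a = 2) predicate:

    import org.apache.spark.sql.sources.{And, EqualTo, In, Not}
    import org.apache.spark.sql.types._

    // Hypothetical schema: "b" is BinaryType, which isSearchableType rejects.
    val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", BinaryType)))
    val filter = Not(And(EqualTo("a", 2), In("b", Array("1"))))

    // The In("b", ...) leaf cannot be converted, so createFilter returns None for the whole
    // filter instead of pushing down NOT(a = 2), which would wrongly drop rows with
    // a = 2 and b not in ("1").
    assert(OrcFilters.createFilter(schema, Seq(filter)).isEmpty)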
+ + case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end()) + + case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end()) + + case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end()) + + case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end()) + + case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startAnd().isNull(attribute, getType(attribute)).end()) + + case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startNot().isNull(attribute, getType(attribute)).end()) + + case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + Some(builder.startAnd().in(attribute, getType(attribute), + castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) + + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a52..19da412f08ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,11 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala new file mode 100644 index 000000000000..e9512bb6d478 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce._ +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcOutputFormat + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types.StructType + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + private lazy val orcStruct: OrcStruct = + OrcUtils.createOrcValue(dataSchema).asInstanceOf[OrcStruct] + + private[this] val writableWrappers = + dataSchema.fields.map(f => OrcUtils.getWritableWrapper(f.dataType)) + + private val recordWriter = { + new OrcOutputFormat[OrcStruct]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + }.getRecordWriter(context) + } + + override def write(row: InternalRow): Unit = { + recordWriter.write( + NullWritable.get, + OrcUtils.convertInternalRowToOrcStruct( + row, dataSchema, Some(writableWrappers), Some(orcStruct))) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala new file mode 100644 index 000000000000..00b7ecf3f3d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
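A usage-level sketch of the write path (the output path is hypothetical, and exact part-file names depend on the job and task IDs): the "compression" option is normalized by OrcOptions, stored into orc.compress by prepareWrite, and read back by getFileExtension, so the files produced by this OrcOutputWriter carry the codec in their extension:

    // Assumes an existing SparkSession named `spark`.
    spark.range(0, 10).write
      .format("orc")
      .option("compression", "zlib")           // normalized to "ZLIB" by OrcOptions
      .save("/tmp/orc_zlib_example")           // hypothetical output path
    // Data files end in ".zlib.orc", i.e.
    // OrcOptions.extensionsForCompressionCodecNames("ZLIB") + ".orc".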
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.io.IOException + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io._ +import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object OrcUtils extends Logging { + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDirectory) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + paths + } + + private[orc] def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + try { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case _: IOException => None + } + } + + private[orc] def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) + : Option[StructType] = { + val conf = sparkSession.sparkContext.hadoopConfiguration + files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + } + } + + private[orc] def getSchemaString(schema: StructType): String = { + schema.fields.map(f => s"${f.name}:${f.dataType.catalogString}").mkString("struct<", ",", ">") + } + + private[orc] def getTypeDescription(dataType: DataType) = dataType match { + case st: StructType => TypeDescription.fromString(getSchemaString(st)) + case _ => TypeDescription.fromString(dataType.catalogString) + } + + /** + * Return a missing schema in a give ORC file. 
+ */ + private[orc] def getMissingSchema( + resolver: Resolver, + dataSchema: StructType, + partitionSchema: StructType, + file: Path, + conf: Configuration): Option[StructType] = { + try { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + val orcSchema = if (schema.getFieldNames.asScala.forall(_.startsWith("_col"))) { + logInfo("Recover ORC schema with data schema") + var schemaString = schema.toString + dataSchema.zipWithIndex.foreach { case (field: StructField, index: Int) => + schemaString = schemaString.replace(s"_col$index:", s"${field.name}:") + } + TypeDescription.fromString(schemaString) + } else { + schema + } + + var missingSchema = new StructType + if (dataSchema.length > orcSchema.getFieldNames.size) { + dataSchema.filter(x => partitionSchema.getFieldIndex(x.name).isEmpty).foreach { f => + if (!orcSchema.getFieldNames.asScala.exists(resolver(_, f.name))) { + missingSchema = missingSchema.add(f) + } + } + } + Some(missingSchema) + } + } catch { + case _: IOException => None + } + } + + /** + * Return a Orc value object for the given Spark schema. + */ + private[orc] def createOrcValue(dataType: DataType) = + OrcStruct.createValue(getTypeDescription(dataType)) + + /** + * Convert Apache ORC OrcStruct to Apache Spark InternalRow. + * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. + */ + private[orc] def convertOrcStructToInternalRow( + orcStruct: OrcStruct, + dataSchema: StructType, + requiredSchema: StructType, + missingSchema: Option[StructType] = None, + valueWrappers: Option[Seq[Any => Any]] = None, + internalRow: Option[InternalRow] = None): InternalRow = { + val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) + val wrappers = + valueWrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(getValueWrapper).toSeq) + var i = 0 + val len = requiredSchema.length + val names = orcStruct.getSchema.getFieldNames + while (i < len) { + val name = requiredSchema(i).name + val writable = if (missingSchema.isEmpty || missingSchema.get.getFieldIndex(name).isEmpty) { + if (names.contains(name)) { + orcStruct.getFieldValue(name) + } else { + orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) + } + } else { + null + } + if (writable == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = wrappers(i)(writable) + } + i += 1 + } + mutableRow + } + + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + + /** + * Builds a catalyst-value return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row. 
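A toy walk-through of the name-recovery branch above, with hypothetical schemas: files written by old Hive writers report only synthetic _colN column names, so the Spark data schema names are substituted back into the file's TypeDescription string before deciding which columns are actually missing:

    import org.apache.orc.TypeDescription
    import org.apache.spark.sql.types._

    val dataSchema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    // Schema as reported by an old Hive-written ORC file.
    val fileSchema = TypeDescription.fromString("struct<_col0:bigint,_col1:string>")

    var schemaString = fileSchema.toString
    dataSchema.zipWithIndex.foreach { case (field, index) =>
      schemaString = schemaString.replace(s"_col$index:", s"${field.name}:")
    }
    // schemaString is now "struct<id:bigint,name:string>", so no field of dataSchema
    // ends up in the returned missing schema.
    val recovered = TypeDescription.fromString(schemaString)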
+ */ + private[orc] def getValueWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + + case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) + + case ByteType => withNullSafe(o => o.asInstanceOf[ByteWritable].get) + case ShortType => withNullSafe(o => o.asInstanceOf[ShortWritable].get) + case IntegerType => withNullSafe(o => o.asInstanceOf[IntWritable].get) + case LongType => withNullSafe(o => o.asInstanceOf[LongWritable].get) + + case FloatType => withNullSafe(o => o.asInstanceOf[FloatWritable].get) + case DoubleType => withNullSafe(o => o.asInstanceOf[DoubleWritable].get) + + case StringType => + withNullSafe(o => UTF8String.fromBytes(o.asInstanceOf[Text].copyBytes)) + + case BinaryType => + withNullSafe { o => + val binary = o.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + bytes + } + + case DateType => + withNullSafe(o => DateTimeUtils.fromJavaDate(o.asInstanceOf[DateWritable].get)) + case TimestampType => + withNullSafe(o => DateTimeUtils.fromJavaTimestamp(o.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => + withNullSafe { o => + val decimal = o.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + v + } + + case _: StructType => + withNullSafe { o => + val structValue = convertOrcStructToInternalRow( + o.asInstanceOf[OrcStruct], + dataType.asInstanceOf[StructType], + dataType.asInstanceOf[StructType]) + structValue + } + + case ArrayType(elementType, _) => + withNullSafe { o => + val wrapper = getValueWrapper(elementType) + val data = new ArrayBuffer[Any] + o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => + data += wrapper(x) + } + new GenericArrayData(data.toArray) + } + + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getValueWrapper(keyType) + val valueWrapper = getValueWrapper(valueType) + val map = new java.util.TreeMap[Any, Any] + o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + .entrySet().asScala.foreach { entry => + map.put(keyWrapper(entry.getKey), valueWrapper(entry.getValue)) + } + ArrayBasedMapData(map.asScala) + } + + case udt: UserDefinedType[_] => + withNullSafe { o => getValueWrapper(udt.sqlType)(o) } + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + /** + * Convert Apache Spark InternalRow to Apache ORC OrcStruct. + */ + private[orc] def convertInternalRowToOrcStruct( + row: InternalRow, + schema: StructType, + valueWrappers: Option[Seq[Any => Any]] = None, + struct: Option[OrcStruct] = None): OrcStruct = { + val wrappers = + valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getWritableWrapper).toSeq) + val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct]) + + for (schemaIndex <- 0 until schema.length) { + val fieldType = schema(schemaIndex).dataType + if (row.isNullAt(schemaIndex)) { + orcStruct.setFieldValue(schemaIndex, null) + } else { + val field = row.get(schemaIndex, fieldType) + val fieldValue = wrappers(schemaIndex)(field).asInstanceOf[WritableComparable[_]] + orcStruct.setFieldValue(schemaIndex, fieldValue) + } + } + orcStruct + } + + /** + * Builds a WritableComparable-return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row. 
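A small sketch of how these unwrapping functions are meant to be used (values are hypothetical, and getValueWrapper is private[orc], so assume a caller in the same package): the wrapper is resolved once per column from its DataType, so each row only pays a function call instead of a full pattern match:

    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val unwrapInt = OrcUtils.getValueWrapper(IntegerType)
    val unwrapStr = OrcUtils.getValueWrapper(StringType)

    unwrapInt(new IntWritable(42))    // 42
    unwrapStr(new Text("spark"))      // UTF8String("spark")
    unwrapInt(null)                   // null, courtesy of withNullSafe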
+ */ + private[orc] def getWritableWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + + case BooleanType => withNullSafe(o => new BooleanWritable(o.asInstanceOf[Boolean])) + + case ByteType => withNullSafe(o => new ByteWritable(o.asInstanceOf[Byte])) + case ShortType => withNullSafe(o => new ShortWritable(o.asInstanceOf[Short])) + case IntegerType => withNullSafe(o => new IntWritable(o.asInstanceOf[Int])) + case LongType => withNullSafe(o => new LongWritable(o.asInstanceOf[Long])) + + case FloatType => withNullSafe(o => new FloatWritable(o.asInstanceOf[Float])) + case DoubleType => withNullSafe(o => new DoubleWritable(o.asInstanceOf[Double])) + + case StringType => withNullSafe(o => new Text(o.asInstanceOf[UTF8String].getBytes)) + + case BinaryType => withNullSafe(o => new BytesWritable(o.asInstanceOf[Array[Byte]])) + + case DateType => + withNullSafe(o => new DateWritable(DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))) + case TimestampType => + withNullSafe { o => + val us = o.asInstanceOf[Long] + var seconds = us / DateTimeUtils.MICROS_PER_SECOND + var micros = us % DateTimeUtils.MICROS_PER_SECOND + if (micros < 0) { + micros += DateTimeUtils.MICROS_PER_SECOND + seconds -= 1 + } + val t = new OrcTimestamp(seconds * 1000) + t.setNanos(micros.toInt * 1000) + t + } + + case _: DecimalType => + withNullSafe { o => + new HiveDecimalWritable(HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + } + + case st: StructType => + withNullSafe(o => convertInternalRowToOrcStruct(o.asInstanceOf[InternalRow], st)) + + case ArrayType(et, _) => + withNullSafe { o => + val data = o.asInstanceOf[ArrayData] + val list = createOrcValue(dataType) + for (i <- 0 until data.numElements()) { + val d = data.get(i, et) + val v = getWritableWrapper(et)(d).asInstanceOf[WritableComparable[_]] + list.asInstanceOf[OrcList[WritableComparable[_]]].add(v) + } + list + } + + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getWritableWrapper(keyType) + val valueWrapper = getWritableWrapper(valueType) + val data = o.asInstanceOf[MapData] + val map = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + data.foreach(keyType, valueType, { case (k, v) => + map.put( + keyWrapper(k).asInstanceOf[WritableComparable[_]], + valueWrapper(v).asInstanceOf[WritableComparable[_]]) + }) + map + } + + case udt: UserDefinedType[_] => + withNullSafe { o => + val udtRow = new SpecificInternalRow(Seq(udt.sqlType)) + udtRow(0) = o + convertInternalRowToOrcStruct( + udtRow, + StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0) + } + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5d0bba69daca..9f4469a09ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1661,11 +1661,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("Path does not exist")) - e = intercept[AnalysisException] { - sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") - } - assert(e.message.contains("The ORC data source must be used with Hive support enabled")) - e = intercept[AnalysisException] { sql(s"select id from `com.databricks.spark.avro`.`file_path`") } @@ -2757,4 +2752,19 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + // Only New OrcFileFormat supports this. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF.limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 3ce6ae3c5292..f22d843bfabd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -53,13 +53,6 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) } - - test("should fail to load ORC without Hive Support") { - val e = intercept[AnalysisException] { - spark.read.format("orc").load() - } - assert(e.message.contains("The ORC data source must be used with Hive support enabled")) - } } diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e7d762fbebe7..d73a2e5dbeae 100644 --- a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1 @@ -org.apache.spark.sql.hive.orc.OrcFileFormat org.apache.spark.sql.hive.execution.HiveFileFormat \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ee1f6ee17306..7c41cba33623 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala deleted file mode 100644 index 3b33a9ff082f..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.orc - -import java.net.URI -import java.util.Properties - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc._ -import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} -import org.apache.orc.OrcConf.COMPRESS - -import org.apache.spark.TaskContext -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.orc.OrcOptions -import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} -import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration - -/** - * `FileFormat` for reading ORC files. If this is moved or renamed, please update - * `DataSource`'s backwardCompatibilityMap. 
- */ -class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable { - - override def shortName(): String = "orc" - - override def toString: String = "ORC" - - override def inferSchema( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Option[StructType] = { - OrcFileOperator.readSchema( - files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) - ) - } - - override def prepareWrite( - sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory = { - val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) - - val configuration = job.getConfiguration - - configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) - configuration match { - case conf: JobConf => - conf.setOutputFormat(classOf[OrcOutputFormat]) - case conf => - conf.setClass( - "mapred.output.format.class", - classOf[OrcOutputFormat], - classOf[MapRedOutputFormat[_, _]]) - } - - new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) - } - - override def getFileExtension(context: TaskAttemptContext): String = { - val compressionExtension: String = { - val name = context.getConfiguration.get(COMPRESS.getAttribute) - OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") - } - - compressionExtension + ".orc" - } - } - } - - override def isSplitable( - sparkSession: SparkSession, - options: Map[String, String], - path: Path): Boolean = { - true - } - - override def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - if (sparkSession.sessionState.conf.orcFilterPushDown) { - // Sets pushed predicates - OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) - hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) - } - } - - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - (file: PartitionedFile) => { - val conf = broadcastedHadoopConf.value.value - - // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this - // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file - // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty - if (isEmptyFile) { - Iterator.empty - } else { - OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) - - val orcRecordReader = { - val job = Job.getInstance(conf) - FileInputFormat.setInputPaths(job, file.filePath) - - val fileSplit = new FileSplit( - new Path(new URI(file.filePath)), file.start, file.length, Array.empty - ) - // Custom OrcRecordReader is used to get - // ObjectInspector during recordReader creation itself and can - // avoid NameNode call in unwrapOrcStructs per file. - // Specifically would be helpful for partitioned datasets. 
- val orcReader = OrcFile.createReader( - new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) - new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) - } - - val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) - - // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcFileFormat.unwrapOrcStructs( - conf, - dataSchema, - requiredSchema, - Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), - recordsIterator) - } - } - } -} - -private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) - extends HiveInspectors { - - def serialize(row: InternalRow): Writable = { - wrapOrcStruct(cachedOrcStruct, structOI, row) - serializer.serialize(cachedOrcStruct, structOI) - } - - private[this] val serializer = { - val table = new Properties() - table.setProperty("columns", dataSchema.fieldNames.mkString(",")) - table.setProperty("columns.types", dataSchema.map(_.dataType.catalogString).mkString(":")) - - val serde = new OrcSerde - serde.initialize(conf, table) - serde - } - - // Object inspector converted from the schema of the relation to be serialized. - private[this] val structOI = { - val typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(dataSchema.catalogString) - OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) - .asInstanceOf[SettableStructObjectInspector] - } - - private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] - - // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format - private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { - case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) - } - - private[this] def wrapOrcStruct( - struct: OrcStruct, - oi: SettableStructObjectInspector, - row: InternalRow): Unit = { - val fieldRefs = oi.getAllStructFieldRefs - var i = 0 - val size = fieldRefs.size - while (i < size) { - - oi.setStructFieldData( - struct, - fieldRefs.get(i), - wrappers(i)(row.get(i, dataSchema(i).dataType)) - ) - i += 1 - } - } -} - -private[orc] class OrcOutputWriter( - path: String, - dataSchema: StructType, - context: TaskAttemptContext) - extends OutputWriter { - - private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) - - // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this - // flag to decide whether `OrcRecordWriter.close()` needs to be called. - private var recordWriterInstantiated = false - - private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { - recordWriterInstantiated = true - new OrcOutputFormat().getRecordWriter( - new Path(path).getFileSystem(context.getConfiguration), - context.getConfiguration.asInstanceOf[JobConf], - path, - Reporter.NULL - ).asInstanceOf[RecordWriter[NullWritable, Writable]] - } - - override def write(row: InternalRow): Unit = { - recordWriter.write(NullWritable.get(), serializer.serialize(row)) - } - - override def close(): Unit = { - if (recordWriterInstantiated) { - recordWriter.close(Reporter.NULL) - } - } -} - -private[orc] object OrcFileFormat extends HiveInspectors { - // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. 
- private[orc] val SARG_PUSHDOWN = "sarg.pushdown" - - // The extensions for ORC compression codecs - val extensionsForCompressionCodecNames = Map( - "NONE" -> "", - "SNAPPY" -> ".snappy", - "ZLIB" -> ".zlib", - "LZO" -> ".lzo") - - def unwrapOrcStructs( - conf: Configuration, - dataSchema: StructType, - requiredSchema: StructType, - maybeStructOI: Option[StructObjectInspector], - iterator: Iterator[Writable]): Iterator[InternalRow] = { - val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(requiredSchema) - - def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { - case (field, ordinal) => - var ref = oi.getStructFieldRef(field.name) - if (ref == null) { - ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) - } - ref -> ordinal - }.unzip - - val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) - - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - val length = fieldRefs.length - while (i < length) { - val fieldRef = fieldRefs(i) - val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) - } - i += 1 - } - unsafeProjection(mutableRow) - } - } - - maybeStructOI.map(unwrap).getOrElse(Iterator.empty) - } - - def setRequiredColumns( - conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) - val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c..aa0be0630dc3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -22,9 +22,10 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.datasources.orc.OrcUtils +import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.types.StructType private[hive] object OrcFileOperator extends Logging { @@ -64,7 +65,7 @@ private[hive] object OrcFileOperator extends Logging { hdfsPath.getFileSystem(conf) } - listOrcFiles(basePath, conf).iterator.map { path => + OrcUtils.listOrcFiles(basePath, conf).iterator.map { path => path -> OrcFile.createReader(fs, path) }.collectFirst { case (path, reader) if isWithNonEmptySchema(path, reader) => reader @@ -87,15 +88,10 @@ private[hive] object OrcFileOperator extends Logging { getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) } - def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { - // TODO: Check if the paths coming in are already qualified and simplify. 
- val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) - .filterNot(_.isDirectory) - .map(_.getPath) - .filterNot(_.getName.startsWith("_")) - .filterNot(_.getName.startsWith(".")) - paths + def setRequiredColumns( + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) + val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index ba0a7605da71..c7c1264dca5d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.catalog.CatalogUtils +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 1fa9091f967a..021c8c495854 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.orc +import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll @@ -59,6 +61,14 @@ case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + private def getFileReader(path: String, extensions: String) = { + import org.apache.orc.OrcFile + val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(extensions)) + assert(maybeOrcFile.isDefined) + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + OrcFile.createReader(orcFilePath, OrcFile.readerOptions(new Configuration())) + } + test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) @@ -230,14 +240,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - // Following codec is not supported in Hive 1.2.1, ignore it now - ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { + test("LZO compression options for writing to an ORC file") { withTempPath { file => spark.range(0, 10).write .option("compression", "LZO") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".lzo.orc").getCompressionKind assert("LZO" === expectedCompressionKind.name()) } } @@ -599,7 +608,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val 
requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileOperator.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 2a086be57f51..1f8ee0becbcb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -225,13 +226,13 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } } -class OrcSourceSuite extends OrcSuite { +class OrcSourceSuite extends OrcSuite with SQLTestUtils { override def beforeAll(): Unit = { super.beforeAll() spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_source - |USING org.apache.spark.sql.hive.orc + |USING orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -239,7 +240,7 @@ class OrcSourceSuite extends OrcSuite { spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_as_source - |USING org.apache.spark.sql.hive.orc + |USING orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -278,4 +279,13 @@ class OrcSourceSuite extends OrcSuite { )).get.toString } } + + test("SPARK-21791 ORC should support column names with dot") { + import spark.implicits._ + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.orc(path) + assert(spark.read.orc(path).collect().length == 2) + } + } } From ca78ac7adc8c4ac5b5e5817afd1cd947bb811640 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 7 Nov 2017 02:51:03 -0800 Subject: [PATCH 02/20] Address comments. --- .../datasources/orc/OrcDeserializer.scala | 167 ++++++++++++ .../datasources/orc/OrcFileFormat.scala | 24 +- .../datasources/orc/OrcFilters.scala | 2 + .../datasources/orc/OrcOptions.scala | 7 - .../datasources/orc/OrcOutputWriter.scala | 14 +- .../datasources/orc/OrcSerializer.scala | 163 +++++++++++ .../execution/datasources/orc/OrcUtils.scala | 255 +----------------- 7 files changed, 349 insertions(+), 283 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala new file mode 100644 index 000000000000..b934aea79c5e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.io._ +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[orc] class OrcDeserializer( + dataSchema: StructType, + requiredSchema: StructType, + maybeMissingSchema: Option[StructType]) { + + private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + + private[this] val valueWrappers = requiredSchema.fields.map(f => getValueWrapper(f.dataType)) + + def deserialize(writable: OrcStruct): InternalRow = { + convertOrcStructToInternalRow(writable, dataSchema, requiredSchema, + maybeMissingSchema, Some(valueWrappers), Some(mutableRow)) + } + + /** + * Convert Apache ORC OrcStruct to Apache Spark InternalRow. + * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. + */ + private[this] def convertOrcStructToInternalRow( + orcStruct: OrcStruct, + dataSchema: StructType, + requiredSchema: StructType, + missingSchema: Option[StructType] = None, + valueWrappers: Option[Seq[Any => Any]] = None, + internalRow: Option[InternalRow] = None): InternalRow = { + val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) + val wrappers = + valueWrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(getValueWrapper).toSeq) + var i = 0 + val len = requiredSchema.length + val names = orcStruct.getSchema.getFieldNames + while (i < len) { + val name = requiredSchema(i).name + val writable = if (missingSchema.isEmpty || missingSchema.get.getFieldIndex(name).isEmpty) { + if (names.contains(name)) { + orcStruct.getFieldValue(name) + } else { + orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) + } + } else { + null + } + if (writable == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = wrappers(i)(writable) + } + i += 1 + } + mutableRow + } + + private[this] def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + + /** + * Builds a catalyst-value return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row. 
+ */ + private[this] def getValueWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + + case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) + + case ByteType => withNullSafe(o => o.asInstanceOf[ByteWritable].get) + case ShortType => withNullSafe(o => o.asInstanceOf[ShortWritable].get) + case IntegerType => withNullSafe(o => o.asInstanceOf[IntWritable].get) + case LongType => withNullSafe(o => o.asInstanceOf[LongWritable].get) + + case FloatType => withNullSafe(o => o.asInstanceOf[FloatWritable].get) + case DoubleType => withNullSafe(o => o.asInstanceOf[DoubleWritable].get) + + case StringType => + withNullSafe(o => UTF8String.fromBytes(o.asInstanceOf[Text].copyBytes)) + + case BinaryType => + withNullSafe { o => + val binary = o.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + bytes + } + + case DateType => + withNullSafe(o => DateTimeUtils.fromJavaDate(o.asInstanceOf[DateWritable].get)) + case TimestampType => + withNullSafe(o => DateTimeUtils.fromJavaTimestamp(o.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => + withNullSafe { o => + val decimal = o.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + v + } + + case _: StructType => + withNullSafe { o => + val structValue = convertOrcStructToInternalRow( + o.asInstanceOf[OrcStruct], + dataType.asInstanceOf[StructType], + dataType.asInstanceOf[StructType]) + structValue + } + + case ArrayType(elementType, _) => + withNullSafe { o => + val wrapper = getValueWrapper(elementType) + val data = new ArrayBuffer[Any] + o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => + data += wrapper(x) + } + new GenericArrayData(data.toArray) + } + + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getValueWrapper(keyType) + val valueWrapper = getValueWrapper(valueType) + val map = new java.util.TreeMap[Any, Any] + o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + .entrySet().asScala.foreach { entry => + map.put(keyWrapper(entry.getKey), valueWrapper(entry.getValue)) + } + ArrayBasedMapData(map.asScala) + } + + case udt: UserDefinedType[_] => + withNullSafe { o => getValueWrapper(udt.sqlType)(o) } + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 9f57efb6b455..ea9fb3bd377c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -111,7 +111,7 @@ class OrcFileFormat override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { val name = context.getConfiguration.get(COMPRESS.getAttribute) - OrcOptions.extensionsForCompressionCodecNames.getOrElse(name, "") + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -126,7 +126,7 @@ class OrcFileFormat true } - override def buildReaderWithPartitionValues( + override def buildReader( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ 
-162,29 +162,15 @@ class OrcFileFormat new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val partitionValues = file.partitionValues - - val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - val mutableRow = new SpecificInternalRow(resultSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(resultSchema) - - // Initialize the partition column values once. - for (i <- requiredSchema.length until resultSchema.length) { - val value = partitionValues.get(i - requiredSchema.length, resultSchema(i).dataType) - mutableRow.update(i, value) - } - - val valueWrappers = requiredSchema.fields.map(f => OrcUtils.getValueWrapper(f.dataType)) - iter.map { value => - unsafeProjection(OrcUtils.convertOrcStructToInternalRow(value, dataSchema, requiredSchema, - maybeMissingSchema, Some(valueWrappers), Some(mutableRow))) - } + val unsafeProjection = UnsafeProjection.create(requiredSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, maybeMissingSchema) + iter.map(value => unsafeProjection(deserializer.deserialize(value))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 6a8d3739892b..de920c8180e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -35,6 +35,8 @@ private[orc] object OrcFilters { def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. 
val convertibleFilters = for { filter <- filters _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index 19da412f08ee..c866dd834a52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,11 +67,4 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") - - // The extensions for ORC compression codecs - val extensionsForCompressionCodecNames = Map( - "NONE" -> "", - "SNAPPY" -> ".snappy", - "ZLIB" -> ".zlib", - "LZO" -> ".lzo") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index e9512bb6d478..84755bfa301f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -19,24 +19,21 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcOutputFormat import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - private lazy val orcStruct: OrcStruct = - OrcUtils.createOrcValue(dataSchema).asInstanceOf[OrcStruct] - private[this] val writableWrappers = - dataSchema.fields.map(f => OrcUtils.getWritableWrapper(f.dataType)) + private[this] val serializer = new OrcSerializer(dataSchema) private val recordWriter = { new OrcOutputFormat[OrcStruct]() { @@ -47,10 +44,7 @@ private[orc] class OrcOutputWriter( } override def write(row: InternalRow): Unit = { - recordWriter.write( - NullWritable.get, - OrcUtils.convertInternalRowToOrcStruct( - row, dataSchema, Some(writableWrappers), Some(orcStruct))) + recordWriter.write(NullWritable.get(), serializer.serialize(row)) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala new file mode 100644 index 000000000000..9d605f1fb309 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.orc.OrcUtils.getTypeDescription +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[orc] class OrcSerializer(dataSchema: StructType) { + + private[this] lazy val orcStruct: OrcStruct = + createOrcValue(dataSchema).asInstanceOf[OrcStruct] + + private[this] val writableWrappers = + dataSchema.fields.map(f => getWritableWrapper(f.dataType)) + + def serialize(row: InternalRow): OrcStruct = { + convertInternalRowToOrcStruct(row, dataSchema, Some(writableWrappers), Some(orcStruct)) + } + + /** + * Return a Orc value object for the given Spark schema. + */ + private[this] def createOrcValue(dataType: DataType) = + OrcStruct.createValue(getTypeDescription(dataType)) + + /** + * Convert Apache Spark InternalRow to Apache ORC OrcStruct. + */ + private[this] def convertInternalRowToOrcStruct( + row: InternalRow, + schema: StructType, + valueWrappers: Option[Seq[Any => Any]] = None, + struct: Option[OrcStruct] = None): OrcStruct = { + val wrappers = + valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getWritableWrapper).toSeq) + val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct]) + + for (schemaIndex <- 0 until schema.length) { + val fieldType = schema(schemaIndex).dataType + if (row.isNullAt(schemaIndex)) { + orcStruct.setFieldValue(schemaIndex, null) + } else { + val field = row.get(schemaIndex, fieldType) + val fieldValue = wrappers(schemaIndex)(field).asInstanceOf[WritableComparable[_]] + orcStruct.setFieldValue(schemaIndex, fieldValue) + } + } + orcStruct + } + + private[this] def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + + /** + * Builds a WritableComparable-return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row. 
+ */ + private[this] def getWritableWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + + case BooleanType => withNullSafe(o => new BooleanWritable(o.asInstanceOf[Boolean])) + + case ByteType => withNullSafe(o => new ByteWritable(o.asInstanceOf[Byte])) + case ShortType => withNullSafe(o => new ShortWritable(o.asInstanceOf[Short])) + case IntegerType => withNullSafe(o => new IntWritable(o.asInstanceOf[Int])) + case LongType => withNullSafe(o => new LongWritable(o.asInstanceOf[Long])) + + case FloatType => withNullSafe(o => new FloatWritable(o.asInstanceOf[Float])) + case DoubleType => withNullSafe(o => new DoubleWritable(o.asInstanceOf[Double])) + + case StringType => withNullSafe(o => new Text(o.asInstanceOf[UTF8String].getBytes)) + + case BinaryType => withNullSafe(o => new BytesWritable(o.asInstanceOf[Array[Byte]])) + + case DateType => + withNullSafe(o => new DateWritable(DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))) + case TimestampType => + withNullSafe { o => + val us = o.asInstanceOf[Long] + var seconds = us / DateTimeUtils.MICROS_PER_SECOND + var micros = us % DateTimeUtils.MICROS_PER_SECOND + if (micros < 0) { + micros += DateTimeUtils.MICROS_PER_SECOND + seconds -= 1 + } + val t = new OrcTimestamp(seconds * 1000) + t.setNanos(micros.toInt * 1000) + t + } + + case _: DecimalType => + withNullSafe { o => + new HiveDecimalWritable(HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + } + + case st: StructType => + withNullSafe(o => convertInternalRowToOrcStruct(o.asInstanceOf[InternalRow], st)) + + case ArrayType(et, _) => + withNullSafe { o => + val data = o.asInstanceOf[ArrayData] + val list = createOrcValue(dataType) + for (i <- 0 until data.numElements()) { + val d = data.get(i, et) + val v = getWritableWrapper(et)(d).asInstanceOf[WritableComparable[_]] + list.asInstanceOf[OrcList[WritableComparable[_]]].add(v) + } + list + } + + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getWritableWrapper(keyType) + val valueWrapper = getWritableWrapper(valueType) + val data = o.asInstanceOf[MapData] + val map = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + data.foreach(keyType, valueType, { case (k, v) => + map.put( + keyWrapper(k).asInstanceOf[WritableComparable[_]], + valueWrapper(v).asInstanceOf[WritableComparable[_]]) + }) + map + } + + case udt: UserDefinedType[_] => + withNullSafe { o => + val udtRow = new SpecificInternalRow(Seq(udt.sqlType)) + udtRow(0) = o + convertInternalRowToOrcStruct( + udtRow, + StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0) + } + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 00b7ecf3f3d0..747a759772e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -20,29 +20,27 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.IOException import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io._ import org.apache.orc.{OrcFile, TypeDescription} -import 
org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} -import org.apache.orc.storage.common.`type`.HiveDecimal -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String object OrcUtils extends Logging { + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) @@ -72,7 +70,7 @@ object OrcUtils extends Logging { private[orc] def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { - val conf = sparkSession.sparkContext.hadoopConfiguration + val conf = sparkSession.sessionState.newHadoopConf() files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] @@ -130,241 +128,4 @@ object OrcUtils extends Logging { case _: IOException => None } } - - /** - * Return a Orc value object for the given Spark schema. - */ - private[orc] def createOrcValue(dataType: DataType) = - OrcStruct.createValue(getTypeDescription(dataType)) - - /** - * Convert Apache ORC OrcStruct to Apache Spark InternalRow. - * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. - */ - private[orc] def convertOrcStructToInternalRow( - orcStruct: OrcStruct, - dataSchema: StructType, - requiredSchema: StructType, - missingSchema: Option[StructType] = None, - valueWrappers: Option[Seq[Any => Any]] = None, - internalRow: Option[InternalRow] = None): InternalRow = { - val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) - val wrappers = - valueWrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(getValueWrapper).toSeq) - var i = 0 - val len = requiredSchema.length - val names = orcStruct.getSchema.getFieldNames - while (i < len) { - val name = requiredSchema(i).name - val writable = if (missingSchema.isEmpty || missingSchema.get.getFieldIndex(name).isEmpty) { - if (names.contains(name)) { - orcStruct.getFieldValue(name) - } else { - orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) - } - } else { - null - } - if (writable == null) { - mutableRow.setNullAt(i) - } else { - mutableRow(i) = wrappers(i)(writable) - } - i += 1 - } - mutableRow - } - - private def withNullSafe(f: Any => Any): Any => Any = { - input => if (input == null) null else f(input) - } - - /** - * Builds a catalyst-value return function ahead of time according to DataType - * to avoid pattern matching and branching costs per row. 
- */ - private[orc] def getValueWrapper(dataType: DataType): Any => Any = dataType match { - case NullType => _ => null - - case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) - - case ByteType => withNullSafe(o => o.asInstanceOf[ByteWritable].get) - case ShortType => withNullSafe(o => o.asInstanceOf[ShortWritable].get) - case IntegerType => withNullSafe(o => o.asInstanceOf[IntWritable].get) - case LongType => withNullSafe(o => o.asInstanceOf[LongWritable].get) - - case FloatType => withNullSafe(o => o.asInstanceOf[FloatWritable].get) - case DoubleType => withNullSafe(o => o.asInstanceOf[DoubleWritable].get) - - case StringType => - withNullSafe(o => UTF8String.fromBytes(o.asInstanceOf[Text].copyBytes)) - - case BinaryType => - withNullSafe { o => - val binary = o.asInstanceOf[BytesWritable] - val bytes = new Array[Byte](binary.getLength) - System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) - bytes - } - - case DateType => - withNullSafe(o => DateTimeUtils.fromJavaDate(o.asInstanceOf[DateWritable].get)) - case TimestampType => - withNullSafe(o => DateTimeUtils.fromJavaTimestamp(o.asInstanceOf[OrcTimestamp])) - - case DecimalType.Fixed(precision, scale) => - withNullSafe { o => - val decimal = o.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) - v.changePrecision(precision, scale) - v - } - - case _: StructType => - withNullSafe { o => - val structValue = convertOrcStructToInternalRow( - o.asInstanceOf[OrcStruct], - dataType.asInstanceOf[StructType], - dataType.asInstanceOf[StructType]) - structValue - } - - case ArrayType(elementType, _) => - withNullSafe { o => - val wrapper = getValueWrapper(elementType) - val data = new ArrayBuffer[Any] - o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => - data += wrapper(x) - } - new GenericArrayData(data.toArray) - } - - case MapType(keyType, valueType, _) => - withNullSafe { o => - val keyWrapper = getValueWrapper(keyType) - val valueWrapper = getValueWrapper(valueType) - val map = new java.util.TreeMap[Any, Any] - o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - .entrySet().asScala.foreach { entry => - map.put(keyWrapper(entry.getKey), valueWrapper(entry.getValue)) - } - ArrayBasedMapData(map.asScala) - } - - case udt: UserDefinedType[_] => - withNullSafe { o => getValueWrapper(udt.sqlType)(o) } - - case _ => - throw new UnsupportedOperationException(s"$dataType is not supported yet.") - } - - /** - * Convert Apache Spark InternalRow to Apache ORC OrcStruct. - */ - private[orc] def convertInternalRowToOrcStruct( - row: InternalRow, - schema: StructType, - valueWrappers: Option[Seq[Any => Any]] = None, - struct: Option[OrcStruct] = None): OrcStruct = { - val wrappers = - valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getWritableWrapper).toSeq) - val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct]) - - for (schemaIndex <- 0 until schema.length) { - val fieldType = schema(schemaIndex).dataType - if (row.isNullAt(schemaIndex)) { - orcStruct.setFieldValue(schemaIndex, null) - } else { - val field = row.get(schemaIndex, fieldType) - val fieldValue = wrappers(schemaIndex)(field).asInstanceOf[WritableComparable[_]] - orcStruct.setFieldValue(schemaIndex, fieldValue) - } - } - orcStruct - } - - /** - * Builds a WritableComparable-return function ahead of time according to DataType - * to avoid pattern matching and branching costs per row. 
- */ - private[orc] def getWritableWrapper(dataType: DataType): Any => Any = dataType match { - case NullType => _ => null - - case BooleanType => withNullSafe(o => new BooleanWritable(o.asInstanceOf[Boolean])) - - case ByteType => withNullSafe(o => new ByteWritable(o.asInstanceOf[Byte])) - case ShortType => withNullSafe(o => new ShortWritable(o.asInstanceOf[Short])) - case IntegerType => withNullSafe(o => new IntWritable(o.asInstanceOf[Int])) - case LongType => withNullSafe(o => new LongWritable(o.asInstanceOf[Long])) - - case FloatType => withNullSafe(o => new FloatWritable(o.asInstanceOf[Float])) - case DoubleType => withNullSafe(o => new DoubleWritable(o.asInstanceOf[Double])) - - case StringType => withNullSafe(o => new Text(o.asInstanceOf[UTF8String].getBytes)) - - case BinaryType => withNullSafe(o => new BytesWritable(o.asInstanceOf[Array[Byte]])) - - case DateType => - withNullSafe(o => new DateWritable(DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))) - case TimestampType => - withNullSafe { o => - val us = o.asInstanceOf[Long] - var seconds = us / DateTimeUtils.MICROS_PER_SECOND - var micros = us % DateTimeUtils.MICROS_PER_SECOND - if (micros < 0) { - micros += DateTimeUtils.MICROS_PER_SECOND - seconds -= 1 - } - val t = new OrcTimestamp(seconds * 1000) - t.setNanos(micros.toInt * 1000) - t - } - - case _: DecimalType => - withNullSafe { o => - new HiveDecimalWritable(HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - case st: StructType => - withNullSafe(o => convertInternalRowToOrcStruct(o.asInstanceOf[InternalRow], st)) - - case ArrayType(et, _) => - withNullSafe { o => - val data = o.asInstanceOf[ArrayData] - val list = createOrcValue(dataType) - for (i <- 0 until data.numElements()) { - val d = data.get(i, et) - val v = getWritableWrapper(et)(d).asInstanceOf[WritableComparable[_]] - list.asInstanceOf[OrcList[WritableComparable[_]]].add(v) - } - list - } - - case MapType(keyType, valueType, _) => - withNullSafe { o => - val keyWrapper = getWritableWrapper(keyType) - val valueWrapper = getWritableWrapper(valueType) - val data = o.asInstanceOf[MapData] - val map = createOrcValue(dataType) - .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - data.foreach(keyType, valueType, { case (k, v) => - map.put( - keyWrapper(k).asInstanceOf[WritableComparable[_]], - valueWrapper(v).asInstanceOf[WritableComparable[_]]) - }) - map - } - - case udt: UserDefinedType[_] => - withNullSafe { o => - val udtRow = new SpecificInternalRow(Seq(udt.sqlType)) - udtRow(0) = o - convertInternalRowToOrcStruct( - udtRow, - StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0) - } - - case _ => - throw new UnsupportedOperationException(s"$dataType is not supported yet.") - } } From 9d18834772e8287172075ec3c5181ff7dc42a4df Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 Nov 2017 10:06:10 -0800 Subject: [PATCH 03/20] Recover OrcFileFormat back, avoid function serialization, add TODO. 
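This commit restores the Hive-based `OrcFileFormat` under `sql/hive`, replaces the captured `Resolver` function with a plain `isCaseSensitive` flag so the per-file reader closure does not have to serialize a function value, and adds a TODO for schema merging (SPARK-11412). A minimal sketch of the resolver point, assuming a helper named `resolverFor` (the helper name is illustrative; the patch inlines the same expression inside `OrcUtils.getMissingSchema`):

    import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution, Resolver}

    // Capture only a serializable Boolean in the task closure and rebuild the
    // (String, String) => Boolean resolver where it is actually used.
    def resolverFor(isCaseSensitive: Boolean): Resolver =
      if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution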
--- .../datasources/orc/OrcFileFormat.scala | 4 +- .../execution/datasources/orc/OrcUtils.scala | 6 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 320 ++++++++++++++++++ 3 files changed, 326 insertions(+), 4 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index ea9fb3bd377c..d3564dd7bf74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -142,13 +142,13 @@ class OrcFileFormat val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val resolver = sparkSession.sessionState.conf.resolver + val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedConf.value.value val maybeMissingSchema = OrcUtils.getMissingSchema( - resolver, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) if (maybeMissingSchema.isEmpty) { Iterator.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 747a759772e2..74ea2834e79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -28,7 +28,7 @@ import org.apache.orc.{OrcFile, TypeDescription} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ @@ -71,6 +71,7 @@ object OrcUtils extends Logging { private[orc] def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { val conf = sparkSession.sessionState.newHadoopConf() + // TODO: We need to support merge schema. Please see SPARK-11412. files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] @@ -90,11 +91,12 @@ object OrcUtils extends Logging { * Return a missing schema in a give ORC file. 
*/ private[orc] def getMissingSchema( - resolver: Resolver, + isCaseSensitive: Boolean, dataSchema: StructType, partitionSchema: StructType, file: Path, conf: Configuration): Option[StructType] = { + val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution try { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala new file mode 100644 index 000000000000..3b33a9ff082f --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.net.URI +import java.util.Properties + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc._ +import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.orc.OrcConf.COMPRESS + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +/** + * `FileFormat` for reading ORC files. If this is moved or renamed, please update + * `DataSource`'s backwardCompatibilityMap. 
+ */ +class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable { + + override def shortName(): String = "orc" + + override def toString: String = "ORC" + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcFileOperator.readSchema( + files.map(_.getPath.toString), + Some(sparkSession.sessionState.newHadoopConf()) + ) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + + val configuration = job.getConfiguration + + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + configuration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + true + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { + // Sets pushed predicates + OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => + hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) + hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val conf = broadcastedHadoopConf.value.value + + // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this + // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file + // using the given physical schema. Instead, we simply return an empty iterator. + val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + if (isEmptyFile) { + Iterator.empty + } else { + OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) + + val orcRecordReader = { + val job = Job.getInstance(conf) + FileInputFormat.setInputPaths(job, file.filePath) + + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), file.start, file.length, Array.empty + ) + // Custom OrcRecordReader is used to get + // ObjectInspector during recordReader creation itself and can + // avoid NameNode call in unwrapOrcStructs per file. + // Specifically would be helpful for partitioned datasets. 
+ val orcReader = OrcFile.createReader( + new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) + new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) + } + + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + + // Unwraps `OrcStruct`s to `UnsafeRow`s + OrcFileFormat.unwrapOrcStructs( + conf, + dataSchema, + requiredSchema, + Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), + recordsIterator) + } + } + } +} + +private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) + extends HiveInspectors { + + def serialize(row: InternalRow): Writable = { + wrapOrcStruct(cachedOrcStruct, structOI, row) + serializer.serialize(cachedOrcStruct, structOI) + } + + private[this] val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map(_.dataType.catalogString).mkString(":")) + + val serde = new OrcSerde + serde.initialize(conf, table) + serde + } + + // Object inspector converted from the schema of the relation to be serialized. + private[this] val structOI = { + val typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(dataSchema.catalogString) + OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) + .asInstanceOf[SettableStructObjectInspector] + } + + private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format + private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { + case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) + } + + private[this] def wrapOrcStruct( + struct: OrcStruct, + oi: SettableStructObjectInspector, + row: InternalRow): Unit = { + val fieldRefs = oi.getAllStructFieldRefs + var i = 0 + val size = fieldRefs.size + while (i < size) { + + oi.setStructFieldData( + struct, + fieldRefs.get(i), + wrappers(i)(row.get(i, dataSchema(i).dataType)) + ) + i += 1 + } + } +} + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + new OrcOutputFormat().getRecordWriter( + new Path(path).getFileSystem(context.getConfiguration), + context.getConfiguration.asInstanceOf[JobConf], + path, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] + } + + override def write(row: InternalRow): Unit = { + recordWriter.write(NullWritable.get(), serializer.serialize(row)) + } + + override def close(): Unit = { + if (recordWriterInstantiated) { + recordWriter.close(Reporter.NULL) + } + } +} + +private[orc] object OrcFileFormat extends HiveInspectors { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. 
+ private[orc] val SARG_PUSHDOWN = "sarg.pushdown" + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") + + def unwrapOrcStructs( + conf: Configuration, + dataSchema: StructType, + requiredSchema: StructType, + maybeStructOI: Option[StructObjectInspector], + iterator: Iterator[Writable]): Iterator[InternalRow] = { + val deserializer = new OrcSerde + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(requiredSchema) + + def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { + val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { + case (field, ordinal) => + var ref = oi.getStructFieldRef(field.name) + if (ref == null) { + ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) + } + ref -> ordinal + }.unzip + + val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) + + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + val length = fieldRefs.length + while (i < length) { + val fieldRef = fieldRefs(i) + val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + unsafeProjection(mutableRow) + } + } + + maybeStructOI.map(unwrap).getOrElse(Iterator.empty) + } + + def setRequiredColumns( + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) + val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) + } +} From 6971cdf7d0330b933c09d11fe415f3aada10609d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 Nov 2017 10:44:03 -0800 Subject: [PATCH 04/20] Simplify `isSearchableType` --- .../spark/sql/execution/datasources/orc/OrcFilters.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index de920c8180e9..99846d1bd777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -50,11 +50,12 @@ private[orc] object OrcFilters { /** * Return true if this is a searchable type in ORC. + * Both CharType and VarcharType are cleaned at AstBuilder. */ private def isSearchableType(dataType: DataType) = dataType match { - case ByteType | ShortType | FloatType | DoubleType => true - case IntegerType | LongType | StringType | BooleanType => true - case TimestampType | _: DecimalType => true + // TODO: SPARK-21787 Support for pushing down filters for DateType in ORC + case BinaryType | DateType => false + case _: AtomicType => true case _ => false } From b3734957fffbe7636718d72e0ff75b5059dbcb7e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 Nov 2017 15:53:11 -0800 Subject: [PATCH 05/20] Avoid boxing for primitive types. 
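This commit replaces the `Any => Any` value wrappers in `OrcDeserializer` with unwrappers of type `(Any, InternalRow, Int) => Unit` that call the typed `InternalRow` setters, so primitive column values are written into the row without being boxed. A minimal sketch of the difference for a single `IntegerType` field (illustrative only, not taken verbatim from the patch):

    import org.apache.hadoop.io.IntWritable
    import org.apache.spark.sql.catalyst.InternalRow

    // Old shape: returning Any forces the primitive Int into a java.lang.Integer
    // before the generic row.update(ordinal, value) call.
    val boxing: Any => Any = value => value.asInstanceOf[IntWritable].get

    // New shape: the typed setter writes the primitive Int directly into the row.
    val unwrapper: (Any, InternalRow, Int) => Unit =
      (value, row, ordinal) => row.setInt(ordinal, value.asInstanceOf[IntWritable].get)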
--- .../datasources/orc/OrcDeserializer.scala | 61 +++++++++++++++---- .../datasources/orc/OrcFilters.scala | 2 + 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index b934aea79c5e..20604d805f8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -37,11 +37,11 @@ private[orc] class OrcDeserializer( private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - private[this] val valueWrappers = requiredSchema.fields.map(f => getValueWrapper(f.dataType)) + private[this] val unwrappers = requiredSchema.fields.map(f => unwrapperFor(f.dataType)) def deserialize(writable: OrcStruct): InternalRow = { convertOrcStructToInternalRow(writable, dataSchema, requiredSchema, - maybeMissingSchema, Some(valueWrappers), Some(mutableRow)) + maybeMissingSchema, Some(unwrappers), Some(mutableRow)) } /** @@ -53,11 +53,11 @@ private[orc] class OrcDeserializer( dataSchema: StructType, requiredSchema: StructType, missingSchema: Option[StructType] = None, - valueWrappers: Option[Seq[Any => Any]] = None, + valueUnwrappers: Option[Seq[(Any, InternalRow, Int) => Unit]] = None, internalRow: Option[InternalRow] = None): InternalRow = { val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) - val wrappers = - valueWrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(getValueWrapper).toSeq) + val unwrappers = + valueUnwrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(unwrapperFor).toSeq) var i = 0 val len = requiredSchema.length val names = orcStruct.getSchema.getFieldNames @@ -75,7 +75,7 @@ private[orc] class OrcDeserializer( if (writable == null) { mutableRow.setNullAt(i) } else { - mutableRow(i) = wrappers(i)(writable) + unwrappers(i)(writable, mutableRow, i) } i += 1 } @@ -90,7 +90,7 @@ private[orc] class OrcDeserializer( * Builds a catalyst-value return function ahead of time according to DataType * to avoid pattern matching and branching costs per row. 
*/ - private[this] def getValueWrapper(dataType: DataType): Any => Any = dataType match { + private[this] def getValueUnwrapper(dataType: DataType): Any => Any = dataType match { case NullType => _ => null case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) @@ -138,7 +138,7 @@ private[orc] class OrcDeserializer( case ArrayType(elementType, _) => withNullSafe { o => - val wrapper = getValueWrapper(elementType) + val wrapper = getValueUnwrapper(elementType) val data = new ArrayBuffer[Any] o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => data += wrapper(x) @@ -148,8 +148,8 @@ private[orc] class OrcDeserializer( case MapType(keyType, valueType, _) => withNullSafe { o => - val keyWrapper = getValueWrapper(keyType) - val valueWrapper = getValueWrapper(valueType) + val keyWrapper = getValueUnwrapper(keyType) + val valueWrapper = getValueUnwrapper(valueType) val map = new java.util.TreeMap[Any, Any] o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] .entrySet().asScala.foreach { entry => @@ -159,9 +159,48 @@ private[orc] class OrcDeserializer( } case udt: UserDefinedType[_] => - withNullSafe { o => getValueWrapper(udt.sqlType)(o) } + withNullSafe { o => getValueUnwrapper(udt.sqlType)(o) } case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") } + + private[this] def unwrapperFor(dataType: DataType): (Any, InternalRow, Int) => Unit = + dataType match { + case NullType => + (value: Any, row: InternalRow, ordinal: Int) => row.setNullAt(ordinal) + + case BooleanType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case _ => + val unwrapper = getValueUnwrapper(dataType) + (value: Any, row: InternalRow, ordinal: Int) => + row(ordinal) = unwrapper(value) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 99846d1bd777..4cd72bf3c6d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -43,7 +43,9 @@ private[orc] object OrcFilters { } yield filter for { + // Combines all convertible filters using `And` to produce a single conjunction conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() } From de8b509c431d3bafe7ae0c3157796422fa9ba71e Mon 
Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 11 Nov 2017 16:03:58 -0800 Subject: [PATCH 06/20] Use getMissingColumnNames instead of getMissingSchema --- .../execution/datasources/orc/OrcDeserializer.scala | 8 ++++---- .../execution/datasources/orc/OrcFileFormat.scala | 10 +++++----- .../sql/execution/datasources/orc/OrcUtils.scala | 13 +++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 20604d805f8e..b6bf5371cecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String private[orc] class OrcDeserializer( dataSchema: StructType, requiredSchema: StructType, - maybeMissingSchema: Option[StructType]) { + maybeMissingSchemaColumnNames: Option[Seq[String]]) { private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) @@ -41,7 +41,7 @@ private[orc] class OrcDeserializer( def deserialize(writable: OrcStruct): InternalRow = { convertOrcStructToInternalRow(writable, dataSchema, requiredSchema, - maybeMissingSchema, Some(unwrappers), Some(mutableRow)) + maybeMissingSchemaColumnNames, Some(unwrappers), Some(mutableRow)) } /** @@ -52,7 +52,7 @@ private[orc] class OrcDeserializer( orcStruct: OrcStruct, dataSchema: StructType, requiredSchema: StructType, - missingSchema: Option[StructType] = None, + missingColumnNames: Option[Seq[String]] = None, valueUnwrappers: Option[Seq[(Any, InternalRow, Int) => Unit]] = None, internalRow: Option[InternalRow] = None): InternalRow = { val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) @@ -63,7 +63,7 @@ private[orc] class OrcDeserializer( val names = orcStruct.getSchema.getFieldNames while (i < len) { val name = requiredSchema(i).name - val writable = if (missingSchema.isEmpty || missingSchema.get.getFieldIndex(name).isEmpty) { + val writable = if (missingColumnNames.isEmpty || !missingColumnNames.contains(name)) { if (names.contains(name)) { orcStruct.getFieldValue(name) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index d3564dd7bf74..c6deb3cc7fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -147,14 +147,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value - val maybeMissingSchema = OrcUtils.getMissingSchema( + val maybeMissingColumnNames = OrcUtils.getMissingColumnNames( isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) - if (maybeMissingSchema.isEmpty) { + if (maybeMissingColumnNames.isEmpty) { Iterator.empty } else { - val missingSchema = maybeMissingSchema.get + val missingColumnNames = maybeMissingColumnNames.get val columns = requiredSchema - .filter(f => missingSchema.getFieldIndex(f.name).isEmpty) + .filter(f => !missingColumnNames.contains(f.name)) .map(f => dataSchema.fieldIndex(f.name)).mkString(",") conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, columns) @@ 
-169,7 +169,7 @@ class OrcFileFormat Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) val unsafeProjection = UnsafeProjection.create(requiredSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, maybeMissingSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, maybeMissingColumnNames) iter.map(value => unsafeProjection(deserializer.deserialize(value))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 74ea2834e79a..cc0ebf80e8e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.IOException import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -88,14 +89,14 @@ object OrcUtils extends Logging { } /** - * Return a missing schema in a give ORC file. + * Return missing column names in a give ORC file. */ - private[orc] def getMissingSchema( + private[orc] def getMissingColumnNames( isCaseSensitive: Boolean, dataSchema: StructType, partitionSchema: StructType, file: Path, - conf: Configuration): Option[StructType] = { + conf: Configuration): Option[Seq[String]] = { val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution try { val fs = file.getFileSystem(conf) @@ -116,15 +117,15 @@ object OrcUtils extends Logging { schema } - var missingSchema = new StructType + val missingColumnNames = new ArrayBuffer[String] if (dataSchema.length > orcSchema.getFieldNames.size) { dataSchema.filter(x => partitionSchema.getFieldIndex(x.name).isEmpty).foreach { f => if (!orcSchema.getFieldNames.asScala.exists(resolver(_, f.name))) { - missingSchema = missingSchema.add(f) + missingColumnNames += f.name } } } - Some(missingSchema) + Some(missingColumnNames) } } catch { case _: IOException => None From 726406f04c514aee9392af75f29c48d15d5abb56 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 11 Nov 2017 16:21:44 -0800 Subject: [PATCH 07/20] Move withNullSafe to OrcUtils --- .../sql/execution/datasources/orc/OrcDeserializer.scala | 5 +---- .../spark/sql/execution/datasources/orc/OrcSerializer.scala | 6 +----- .../spark/sql/execution/datasources/orc/OrcUtils.scala | 4 ++++ .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 5 +---- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index b6bf5371cecc..6deeaf0408a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -27,6 +27,7 @@ import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.orc.OrcUtils.withNullSafe import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.types.UTF8String @@ -82,10 +83,6 @@ private[orc] class OrcDeserializer( mutableRow } - private[this] def withNullSafe(f: Any => Any): Any => Any = { - input => if (input == null) null else f(input) - } - /** * Builds a catalyst-value return function ahead of time according to DataType * to avoid pattern matching and branching costs per row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 9d605f1fb309..de61dff54b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -25,7 +25,7 @@ import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.orc.OrcUtils.getTypeDescription +import org.apache.spark.sql.execution.datasources.orc.OrcUtils.{getTypeDescription, withNullSafe} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -72,10 +72,6 @@ private[orc] class OrcSerializer(dataSchema: StructType) { orcStruct } - private[this] def withNullSafe(f: Any => Any): Any => Any = { - input => if (input == null) null else f(input) - } - /** * Builds a WritableComparable-return function ahead of time according to DataType * to avoid pattern matching and branching costs per row. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index cc0ebf80e8e2..ffe70932acb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -42,6 +42,10 @@ object OrcUtils extends Logging { "ZLIB" -> ".zlib", "LZO" -> ".lzo") + def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4dec2f71b8a5..7e6f89932b74 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.orc.OrcUtils.withNullSafe import org.apache.spark.sql.types import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -256,10 +257,6 @@ private[hive] trait HiveInspectors { case _ => false } - private def withNullSafe(f: Any => Any): Any => Any = { - input => if (input == null) null else f(input) - } - /** * Wraps with Hive types based on object inspector. 
*/ From 40974577c76db5137947ccf354e43c97faf340bf Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 11 Nov 2017 23:15:23 -0800 Subject: [PATCH 08/20] Handle top-level columns in a while-loop and split the logic for Struct field --- .../datasources/orc/OrcDeserializer.scala | 48 ++++++++++++------- .../datasources/orc/OrcFileFormat.scala | 2 +- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 6deeaf0408a4..993a2d70161a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -34,15 +34,35 @@ import org.apache.spark.unsafe.types.UTF8String private[orc] class OrcDeserializer( dataSchema: StructType, requiredSchema: StructType, - maybeMissingSchemaColumnNames: Option[Seq[String]]) { + missingColumnNames: Seq[String]) { private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) private[this] val unwrappers = requiredSchema.fields.map(f => unwrapperFor(f.dataType)) - def deserialize(writable: OrcStruct): InternalRow = { - convertOrcStructToInternalRow(writable, dataSchema, requiredSchema, - maybeMissingSchemaColumnNames, Some(unwrappers), Some(mutableRow)) + def deserialize(orcStruct: OrcStruct): InternalRow = { + var i = 0 + val len = requiredSchema.length + val names = orcStruct.getSchema.getFieldNames + while (i < len) { + val name = requiredSchema(i).name + val writable = if (missingColumnNames.contains(name)) { + null + } else { + if (names.contains(name)) { + orcStruct.getFieldValue(name) + } else { + orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) + } + } + if (writable == null) { + mutableRow.setNullAt(i) + } else { + unwrappers(i)(writable, mutableRow, i) + } + i += 1 + } + mutableRow } /** @@ -52,26 +72,18 @@ private[orc] class OrcDeserializer( private[this] def convertOrcStructToInternalRow( orcStruct: OrcStruct, dataSchema: StructType, - requiredSchema: StructType, - missingColumnNames: Option[Seq[String]] = None, - valueUnwrappers: Option[Seq[(Any, InternalRow, Int) => Unit]] = None, - internalRow: Option[InternalRow] = None): InternalRow = { - val mutableRow = internalRow.getOrElse(new SpecificInternalRow(requiredSchema.map(_.dataType))) - val unwrappers = - valueUnwrappers.getOrElse(requiredSchema.fields.map(_.dataType).map(unwrapperFor).toSeq) + requiredSchema: StructType): InternalRow = { + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unwrappers = requiredSchema.fields.map(_.dataType).map(unwrapperFor).toSeq var i = 0 val len = requiredSchema.length val names = orcStruct.getSchema.getFieldNames while (i < len) { val name = requiredSchema(i).name - val writable = if (missingColumnNames.isEmpty || !missingColumnNames.contains(name)) { - if (names.contains(name)) { - orcStruct.getFieldValue(name) - } else { - orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) - } + val writable = if (names.contains(name)) { + orcStruct.getFieldValue(name) } else { - null + orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) } if (writable == null) { mutableRow.setNullAt(i) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index c6deb3cc7fac..9809e6ee984d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -169,7 +169,7 @@ class OrcFileFormat Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) val unsafeProjection = UnsafeProjection.create(requiredSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, maybeMissingColumnNames) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, missingColumnNames) iter.map(value => unsafeProjection(deserializer.deserialize(value))) } } From 8e0d392602ac30e938a06a7414ea2058e046fedc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 12 Nov 2017 00:35:09 -0800 Subject: [PATCH 09/20] fix --- .../datasources/orc/OrcDeserializer.scala | 86 +++++++++--------- .../datasources/orc/OrcSerializer.scala | 90 ++++++++++++++----- 2 files changed, 112 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 993a2d70161a..4c68df241bb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -38,13 +38,14 @@ private[orc] class OrcDeserializer( private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - private[this] val unwrappers = requiredSchema.fields.map(f => unwrapperFor(f.dataType)) + private[this] val length = requiredSchema.length + + private[this] val unwrappers = requiredSchema.map(_.dataType).map(unwrapperFor).toArray def deserialize(orcStruct: OrcStruct): InternalRow = { var i = 0 - val len = requiredSchema.length val names = orcStruct.getSchema.getFieldNames - while (i < len) { + while (i < length) { val name = requiredSchema(i).name val writable = if (missingColumnNames.contains(name)) { null @@ -65,6 +66,46 @@ private[orc] class OrcDeserializer( mutableRow } + private[this] def unwrapperFor(dataType: DataType): (Any, InternalRow, Int) => Unit = + dataType match { + case NullType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setNullAt(ordinal) + + case BooleanType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => + (value: Any, row: InternalRow, ordinal: Int) => + row.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case _ => + val unwrapper = getValueUnwrapper(dataType) + (value: Any, row: InternalRow, ordinal: Int) => + row(ordinal) = 
unwrapper(value) + } + /** * Convert Apache ORC OrcStruct to Apache Spark InternalRow. * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. @@ -173,43 +214,4 @@ private[orc] class OrcDeserializer( case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") } - - private[this] def unwrapperFor(dataType: DataType): (Any, InternalRow, Int) => Unit = - dataType match { - case NullType => - (value: Any, row: InternalRow, ordinal: Int) => row.setNullAt(ordinal) - - case BooleanType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) - - case ByteType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setByte(ordinal, value.asInstanceOf[ByteWritable].get) - - case ShortType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setShort(ordinal, value.asInstanceOf[ShortWritable].get) - - case IntegerType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setInt(ordinal, value.asInstanceOf[IntWritable].get) - - case LongType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setLong(ordinal, value.asInstanceOf[LongWritable].get) - - case FloatType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) - - case DoubleType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) - - case _ => - val unwrapper = getValueUnwrapper(dataType) - (value: Any, row: InternalRow, ordinal: Int) => - row(ordinal) = unwrapper(value) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index de61dff54b2a..8b12d8c0801e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -23,7 +23,7 @@ import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.orc.OrcUtils.{getTypeDescription, withNullSafe} import org.apache.spark.sql.types._ @@ -31,14 +31,62 @@ import org.apache.spark.unsafe.types.UTF8String private[orc] class OrcSerializer(dataSchema: StructType) { - private[this] lazy val orcStruct: OrcStruct = - createOrcValue(dataSchema).asInstanceOf[OrcStruct] + private[this] lazy val orcStruct: OrcStruct = createOrcValue(dataSchema).asInstanceOf[OrcStruct] - private[this] val writableWrappers = - dataSchema.fields.map(f => getWritableWrapper(f.dataType)) + private[this] lazy val length = dataSchema.length + + private[this] val writers = dataSchema.map(_.dataType).map(makeWriter).toArray def serialize(row: InternalRow): OrcStruct = { - convertInternalRowToOrcStruct(row, dataSchema, Some(writableWrappers), Some(orcStruct)) + var i = 0 + while (i < length) { + if (row.isNullAt(i)) { + orcStruct.setFieldValue(i, null) + } else { + writers(i)(row, i) + } + i += 1 + } + orcStruct + } + + private[this] def makeWriter(dataType: DataType): (SpecializedGetters, Int) => Unit = { + 
dataType match { + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new BooleanWritable(row.getBoolean(ordinal))) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new ByteWritable(row.getByte(ordinal))) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new ShortWritable(row.getShort(ordinal))) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new IntWritable(row.getInt(ordinal))) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new LongWritable(row.getLong(ordinal))) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new FloatWritable(row.getFloat(ordinal))) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + orcStruct.setFieldValue(ordinal, new DoubleWritable(row.getDouble(ordinal))) + + case _ => + val wrapper = getWritableWrapper(dataType) + (row: SpecializedGetters, ordinal: Int) => { + val value = wrapper(row.get(ordinal, dataType)).asInstanceOf[WritableComparable[_]] + orcStruct.setFieldValue(ordinal, value) + } + } } /** @@ -50,24 +98,22 @@ private[orc] class OrcSerializer(dataSchema: StructType) { /** * Convert Apache Spark InternalRow to Apache ORC OrcStruct. */ - private[this] def convertInternalRowToOrcStruct( - row: InternalRow, - schema: StructType, - valueWrappers: Option[Seq[Any => Any]] = None, - struct: Option[OrcStruct] = None): OrcStruct = { - val wrappers = - valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getWritableWrapper).toSeq) - val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct]) - - for (schemaIndex <- 0 until schema.length) { - val fieldType = schema(schemaIndex).dataType - if (row.isNullAt(schemaIndex)) { - orcStruct.setFieldValue(schemaIndex, null) + private[this] def convertInternalRowToOrcStruct(row: InternalRow, schema: StructType) = { + val wrappers = schema.map(_.dataType).map(getWritableWrapper).toArray + val orcStruct = createOrcValue(schema).asInstanceOf[OrcStruct] + + var i = 0 + val length = schema.length + while (i < length) { + val fieldType = schema(i).dataType + if (row.isNullAt(i)) { + orcStruct.setFieldValue(i, null) } else { - val field = row.get(schemaIndex, fieldType) - val fieldValue = wrappers(schemaIndex)(field).asInstanceOf[WritableComparable[_]] - orcStruct.setFieldValue(schemaIndex, fieldValue) + val field = row.get(i, fieldType) + val fieldValue = wrappers(i)(field).asInstanceOf[WritableComparable[_]] + orcStruct.setFieldValue(i, fieldValue) } + i += 1 } orcStruct } From 9e3ac1ab60e53f5840df163dc3fc0491652add6b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 12 Nov 2017 16:38:44 -0800 Subject: [PATCH 10/20] Address comments. 
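The serializer and deserializer rewritten in PATCH 09 above build one writer or unwrapper closure per field up front, so the per-DataType pattern match runs once per task instead of once per row; the hot path is then just a while loop over precomputed closures. A minimal standalone sketch of that idea follows, using plain Scala stand-ins (FieldType, Array[Any] rows) rather than Spark's InternalRow or the ORC writables; all names here are illustrative, not part of the patch.

// Illustrative sketch only: precompute one writer closure per field so the
// per-type dispatch happens once, not once per row. Plain Scala stand-ins
// are used instead of Spark's InternalRow / ORC writables.
object PerFieldWriterSketch {
  sealed trait FieldType
  case object IntField extends FieldType
  case object StringField extends FieldType

  // Build the per-field writers once, up front.
  def makeWriters(schema: Seq[FieldType]): Array[(Array[Any], Int, Any) => Unit] =
    schema.map {
      case IntField =>
        (row: Array[Any], ordinal: Int, value: Any) => row(ordinal) = value.asInstanceOf[Int]
      case StringField =>
        (row: Array[Any], ordinal: Int, value: Any) => row(ordinal) = value.toString
    }.toArray

  // Hot path: a tight while loop that only invokes the precomputed closures.
  def writeRow(writers: Array[(Array[Any], Int, Any) => Unit], values: Array[Any]): Array[Any] = {
    val row = new Array[Any](writers.length)
    var i = 0
    while (i < writers.length) {
      if (values(i) == null) row(i) = null else writers(i)(row, i, values(i))
      i += 1
    }
    row
  }

  def main(args: Array[String]): Unit = {
    val writers = makeWriters(Seq(IntField, StringField))
    println(writeRow(writers, Array(1, "a")).mkString(", "))  // prints: 1, a
  }
}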
--- .../execution/datasources/orc/OrcDeserializer.scala | 1 - .../sql/execution/datasources/orc/OrcFileFormat.scala | 4 +++- .../sql/execution/datasources/orc/OrcUtils.scala | 11 +++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 4c68df241bb9..5238a21a7583 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -108,7 +108,6 @@ private[orc] class OrcDeserializer( /** * Convert Apache ORC OrcStruct to Apache Spark InternalRow. - * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. */ private[this] def convertOrcStructToInternalRow( orcStruct: OrcStruct, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 9809e6ee984d..b595c098441f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -93,7 +93,7 @@ class OrcFileFormat val conf = job.getConfiguration - conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.getSchemaString(dataSchema)) + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) @@ -147,6 +147,8 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. + // In this case, `getMissingColumnNames` returns `None` and we return an empty iterator. val maybeMissingColumnNames = OrcUtils.getMissingColumnNames( isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) if (maybeMissingColumnNames.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index ffe70932acb3..3b28fdeec59d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -83,17 +83,16 @@ object OrcUtils extends Logging { } } - private[orc] def getSchemaString(schema: StructType): String = { - schema.fields.map(f => s"${f.name}:${f.dataType.catalogString}").mkString("struct<", ",", ">") - } - private[orc] def getTypeDescription(dataType: DataType) = dataType match { - case st: StructType => TypeDescription.fromString(getSchemaString(st)) + case st: StructType => TypeDescription.fromString(st.catalogString) case _ => TypeDescription.fromString(dataType.catalogString) } /** - * Return missing column names in a give ORC file. + * Return missing column names in a give ORC file or `None`. + * `None` is returned for the following cases. OrcFileFormat will handle as empty iterators. + * - Some old empty ORC files always have an empty schema stored in their footer. (SPARK-8501) + * - Other IOExceptions during reading schema. 
*/ private[orc] def getMissingColumnNames( isCaseSensitive: Boolean, From a3ebfbfcaf2a975b4045629e5d20238da51b4be6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 12 Nov 2017 17:07:52 -0800 Subject: [PATCH 11/20] fix --- .../execution/datasources/orc/OrcUtils.scala | 73 ++++++++----------- 1 file changed, 32 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 3b28fdeec59d..5c24cda5ac33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -58,18 +58,14 @@ object OrcUtils extends Logging { } private[orc] def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { - try { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) - } - } catch { - case _: IOException => None + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) } } @@ -90,9 +86,8 @@ object OrcUtils extends Logging { /** * Return missing column names in a give ORC file or `None`. - * `None` is returned for the following cases. OrcFileFormat will handle as empty iterators. - * - Some old empty ORC files always have an empty schema stored in their footer. (SPARK-8501) - * - Other IOExceptions during reading schema. + * Some old empty ORC files always have an empty schema stored in their footer. (SPARK-8501) + * In that case, `None` is returned and OrcFileFormat will handle as empty iterators. 
*/ private[orc] def getMissingColumnNames( isCaseSensitive: Boolean, @@ -101,37 +96,33 @@ object OrcUtils extends Logging { file: Path, conf: Configuration): Option[Seq[String]] = { val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution - try { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - val orcSchema = if (schema.getFieldNames.asScala.forall(_.startsWith("_col"))) { - logInfo("Recover ORC schema with data schema") - var schemaString = schema.toString - dataSchema.zipWithIndex.foreach { case (field: StructField, index: Int) => - schemaString = schemaString.replace(s"_col$index:", s"${field.name}:") - } - TypeDescription.fromString(schemaString) - } else { - schema + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + val orcSchema = if (schema.getFieldNames.asScala.forall(_.startsWith("_col"))) { + logInfo("Recover ORC schema with data schema") + var schemaString = schema.toString + dataSchema.zipWithIndex.foreach { case (field: StructField, index: Int) => + schemaString = schemaString.replace(s"_col$index:", s"${field.name}:") } + TypeDescription.fromString(schemaString) + } else { + schema + } - val missingColumnNames = new ArrayBuffer[String] - if (dataSchema.length > orcSchema.getFieldNames.size) { - dataSchema.filter(x => partitionSchema.getFieldIndex(x.name).isEmpty).foreach { f => - if (!orcSchema.getFieldNames.asScala.exists(resolver(_, f.name))) { - missingColumnNames += f.name - } + val missingColumnNames = new ArrayBuffer[String] + if (dataSchema.length > orcSchema.getFieldNames.size) { + dataSchema.filter(x => partitionSchema.getFieldIndex(x.name).isEmpty).foreach { f => + if (!orcSchema.getFieldNames.asScala.exists(resolver(_, f.name))) { + missingColumnNames += f.name } } - Some(missingColumnNames) } - } catch { - case _: IOException => None + Some(missingColumnNames) } } } From cc40fba4e4f12d7f65dd7f426fc20f3b5c2a7b6e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 21 Nov 2017 22:29:48 -0800 Subject: [PATCH 12/20] Move out column name handling logic. 
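The column-name handling consolidated in the patches above has two cases: the ORC file either records the real field names, or (for files written by old Hive versions) only positional names of the form "_col0", "_col1", ..., in which case the reader falls back to "_col" + dataSchema.fieldIndex(name). A small sketch of that lookup, reduced to plain Seq[String] schemas rather than the patch's StructType/OrcStruct code:

// Illustrative sketch of the two-way field lookup used by the deserializer:
// use the real column name when the ORC file recorded it, otherwise fall back
// to the positional "_col<i>" name that old Hive writers produced.
object OrcFieldNameSketch {
  def physicalFieldName(
      name: String,
      orcFieldNames: Seq[String],
      dataSchemaNames: Seq[String]): String = {
    if (orcFieldNames.contains(name)) {
      name                                     // file stores real names
    } else {
      "_col" + dataSchemaNames.indexOf(name)   // old Hive files: positional names
    }
  }

  def main(args: Array[String]): Unit = {
    val dataSchema = Seq("id", "name", "value")
    // A file written by old Hive: only positional names are stored in the footer.
    println(physicalFieldName("name", Seq("_col0", "_col1", "_col2"), dataSchema))  // _col1
    // A file written with real field names.
    println(physicalFieldName("name", Seq("id", "name", "value"), dataSchema))      // name
  }
}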
--- .../datasources/orc/OrcDeserializer.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 5238a21a7583..35ea10ecb948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -40,14 +40,19 @@ private[orc] class OrcDeserializer( private[this] val length = requiredSchema.length - private[this] val unwrappers = requiredSchema.map(_.dataType).map(unwrapperFor).toArray + private[this] val unwrappers = requiredSchema.map { f => + if (missingColumnNames.contains(f.name)) { + (value: Any, row: InternalRow, ordinal: Int) => row.setNullAt(ordinal) + } else { + unwrapperFor(f.dataType) + } + }.toArray def deserialize(orcStruct: OrcStruct): InternalRow = { - var i = 0 val names = orcStruct.getSchema.getFieldNames - while (i < length) { - val name = requiredSchema(i).name - val writable = if (missingColumnNames.contains(name)) { + val fieldRefs = requiredSchema.map { f => + val name = f.name + if (missingColumnNames.contains(name)) { null } else { if (names.contains(name)) { @@ -56,6 +61,11 @@ private[orc] class OrcDeserializer( orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) } } + }.toArray + + var i = 0 + while (i < length) { + val writable = fieldRefs(i) if (writable == null) { mutableRow.setNullAt(i) } else { From f482179c48dfad970fd85840be86ca6f4534888a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 28 Nov 2017 10:03:18 -0800 Subject: [PATCH 13/20] Use Updater like Parquet. 
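The patch below reshapes the deserializer around converter objects that write decoded values through small "updater" interfaces (a row updater, an array updater), mirroring the structure of the Parquet reader. A reduced sketch of that shape is given here with plain Scala containers; the class names echo the patch but the bodies are simplified stand-ins, not the Spark classes.

// Reduced sketch of the converter/updater shape: a converter knows how to
// decode one value, an updater knows where to put the decoded result.
// Plain Scala stand-ins are used instead of Spark's InternalRow/ArrayData.
object UpdaterPatternSketch {
  trait DataUpdater {
    def set(ordinal: Int, value: Any): Unit
    def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
  }

  final class RowUpdater(row: Array[Any]) extends DataUpdater {
    def set(ordinal: Int, value: Any): Unit = row(ordinal) = value
  }

  // A "converter" is a function from (ordinal, raw value) to a side effect on
  // the updater; one converter is built per field, once per task.
  def newConverter(isInt: Boolean, updater: DataUpdater): (Int, Any) => Unit =
    if (isInt) (ordinal: Int, value: Any) => updater.setInt(ordinal, value.asInstanceOf[Int])
    else (ordinal: Int, value: Any) => updater.set(ordinal, value.toString)

  def main(args: Array[String]): Unit = {
    val row = new Array[Any](2)
    val updater = new RowUpdater(row)
    val converters =
      Array(newConverter(isInt = true, updater = updater), newConverter(isInt = false, updater = updater))
    converters(0)(0, 42)
    converters(1)(1, "ok")
    println(row.mkString(", "))  // prints: 42, ok
  }
}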
--- .../datasources/orc/OrcDeserializer.scala | 330 +++++++++++------- .../datasources/orc/OrcFileFormat.scala | 4 +- .../datasources/orc/OrcFilters.scala | 29 +- .../datasources/orc/OrcSerializer.scala | 2 +- 4 files changed, 234 insertions(+), 131 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 35ea10ecb948..6a051684311c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -27,7 +27,6 @@ import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.orc.OrcUtils.withNullSafe import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -36,16 +35,17 @@ private[orc] class OrcDeserializer( requiredSchema: StructType, missingColumnNames: Seq[String]) { - private[this] val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + private[this] val currentRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) private[this] val length = requiredSchema.length - private[this] val unwrappers = requiredSchema.map { f => - if (missingColumnNames.contains(f.name)) { - (value: Any, row: InternalRow, ordinal: Int) => row.setNullAt(ordinal) - } else { - unwrapperFor(f.dataType) - } + private[this] val fieldConverters: Array[Converter] = requiredSchema.zipWithIndex.map { + case (f, ordinal) => + if (missingColumnNames.contains(f.name)) { + null + } else { + newConverter(f.dataType, new RowUpdater(currentRow, ordinal)) + } }.toArray def deserialize(orcStruct: OrcStruct): InternalRow = { @@ -67,160 +67,238 @@ private[orc] class OrcDeserializer( while (i < length) { val writable = fieldRefs(i) if (writable == null) { - mutableRow.setNullAt(i) + currentRow.setNullAt(i) } else { - unwrappers(i)(writable, mutableRow, i) + fieldConverters(i).set(writable) } i += 1 } - mutableRow + currentRow } - private[this] def unwrapperFor(dataType: DataType): (Any, InternalRow, Int) => Unit = + private[this] def newConverter(dataType: DataType, updater: OrcDataUpdater): Converter = dataType match { case NullType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setNullAt(ordinal) + new Converter { + override def set(value: Any): Unit = updater.setNullAt() + } case BooleanType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + new Converter { + override def set(value: Any): Unit = + updater.setBoolean(value.asInstanceOf[BooleanWritable].get) + } case ByteType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + new Converter { + override def set(value: Any): Unit = updater.setByte(value.asInstanceOf[ByteWritable].get) + } case ShortType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + new Converter { + override def set(value: Any): Unit = + updater.setShort(value.asInstanceOf[ShortWritable].get) + } case IntegerType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setInt(ordinal, value.asInstanceOf[IntWritable].get) + 
new Converter { + override def set(value: Any): Unit = updater.setInt(value.asInstanceOf[IntWritable].get) + } case LongType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setLong(ordinal, value.asInstanceOf[LongWritable].get) + new Converter { + override def set(value: Any): Unit = updater.setLong(value.asInstanceOf[LongWritable].get) + } case FloatType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + new Converter { + override def set(value: Any): Unit = + updater.setFloat(value.asInstanceOf[FloatWritable].get) + } case DoubleType => - (value: Any, row: InternalRow, ordinal: Int) => - row.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + new Converter { + override def set(value: Any): Unit = + updater.setDouble(value.asInstanceOf[DoubleWritable].get) + } - case _ => - val unwrapper = getValueUnwrapper(dataType) - (value: Any, row: InternalRow, ordinal: Int) => - row(ordinal) = unwrapper(value) - } + case StringType => + new Converter { + override def set(value: Any): Unit = + updater.set(UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) + } - /** - * Convert Apache ORC OrcStruct to Apache Spark InternalRow. - */ - private[this] def convertOrcStructToInternalRow( - orcStruct: OrcStruct, - dataSchema: StructType, - requiredSchema: StructType): InternalRow = { - val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - val unwrappers = requiredSchema.fields.map(_.dataType).map(unwrapperFor).toSeq - var i = 0 - val len = requiredSchema.length - val names = orcStruct.getSchema.getFieldNames - while (i < len) { - val name = requiredSchema(i).name - val writable = if (names.contains(name)) { - orcStruct.getFieldValue(name) - } else { - orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) - } - if (writable == null) { - mutableRow.setNullAt(i) - } else { - unwrappers(i)(writable, mutableRow, i) - } - i += 1 - } - mutableRow - } + case BinaryType => + new Converter { + override def set(value: Any): Unit = { + val binary = value.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + updater.set(bytes) + } + } + + case DateType => + new Converter { + override def set(value: Any): Unit = + updater.set(DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + } - /** - * Builds a catalyst-value return function ahead of time according to DataType - * to avoid pattern matching and branching costs per row. 
- */ - private[this] def getValueUnwrapper(dataType: DataType): Any => Any = dataType match { - case NullType => _ => null + case TimestampType => + new Converter { + override def set(value: Any): Unit = + updater.set(DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) + } - case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) + case DecimalType.Fixed(precision, scale) => + new Converter { + override def set(value: Any): Unit = { + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + updater.set(v) + } + } - case ByteType => withNullSafe(o => o.asInstanceOf[ByteWritable].get) - case ShortType => withNullSafe(o => o.asInstanceOf[ShortWritable].get) - case IntegerType => withNullSafe(o => o.asInstanceOf[IntWritable].get) - case LongType => withNullSafe(o => o.asInstanceOf[LongWritable].get) + case st: StructType => + new Converter { + override def set(value: Any): Unit = { + val orcStruct = value.asInstanceOf[OrcStruct] + val mutableRow = new SpecificInternalRow(st) + val fieldConverters: Array[Converter] = st.zipWithIndex.map { case (f, ordinal) => + if (missingColumnNames.contains(f.name)) { + null + } else { + newConverter(f.dataType, new RowUpdater(mutableRow, ordinal)) + } + }.toArray + + var i = 0 + val length = st.fields.length + while (i < length) { + val name = st(i).name + val writable = orcStruct.getFieldValue(name) + if (writable == null) { + mutableRow.setNullAt(i) + } else { + fieldConverters(i).set(writable) + } + i += 1 + } + updater.set(mutableRow) + } + } - case FloatType => withNullSafe(o => o.asInstanceOf[FloatWritable].get) - case DoubleType => withNullSafe(o => o.asInstanceOf[DoubleWritable].get) + case ArrayType(elementType, _) => + new Converter { + override def set(value: Any): Unit = { + val arrayDataUpdater = new ArrayDataUpdater(updater) + val converter = newConverter(elementType, arrayDataUpdater) + value.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => + if (x == null) { + arrayDataUpdater.set(null) + } else { + converter.set(x) + } + } + arrayDataUpdater.end() + } + } - case StringType => - withNullSafe(o => UTF8String.fromBytes(o.asInstanceOf[Text].copyBytes)) + case MapType(keyType, valueType, _) => + new Converter { + override def set(value: Any): Unit = { + val mapDataUpdater = new MapDataUpdater(keyType, valueType, updater) + mapDataUpdater.set(value) + mapDataUpdater.end() + } + } - case BinaryType => - withNullSafe { o => - val binary = o.asInstanceOf[BytesWritable] - val bytes = new Array[Byte](binary.getLength) - System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) - bytes - } + case udt: UserDefinedType[_] => + new Converter { + override def set(value: Any): Unit = { + val mutableRow = new SpecificInternalRow(new StructType().add("_col1", udt.sqlType)) + val converter = newConverter(udt.sqlType, new RowUpdater(mutableRow, 0)) + converter.set(value) + updater.set(mutableRow.get(0, dataType)) + } + } - case DateType => - withNullSafe(o => DateTimeUtils.fromJavaDate(o.asInstanceOf[DateWritable].get)) - case TimestampType => - withNullSafe(o => DateTimeUtils.fromJavaTimestamp(o.asInstanceOf[OrcTimestamp])) - - case DecimalType.Fixed(precision, scale) => - withNullSafe { o => - val decimal = o.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) - v.changePrecision(precision, 
scale) - v - } + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } - case _: StructType => - withNullSafe { o => - val structValue = convertOrcStructToInternalRow( - o.asInstanceOf[OrcStruct], - dataType.asInstanceOf[StructType], - dataType.asInstanceOf[StructType]) - structValue - } - case ArrayType(elementType, _) => - withNullSafe { o => - val wrapper = getValueUnwrapper(elementType) - val data = new ArrayBuffer[Any] - o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => - data += wrapper(x) - } - new GenericArrayData(data.toArray) - } + // -------------------------------------------------------------------------- + // Converter and Updaters + // -------------------------------------------------------------------------- - case MapType(keyType, valueType, _) => - withNullSafe { o => - val keyWrapper = getValueUnwrapper(keyType) - val valueWrapper = getValueUnwrapper(valueType) - val map = new java.util.TreeMap[Any, Any] - o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - .entrySet().asScala.foreach { entry => - map.put(keyWrapper(entry.getKey), valueWrapper(entry.getValue)) + trait Converter { + def set(value: Any): Unit + } + + trait OrcDataUpdater { + def setNullAt(): Unit = () + def set(value: Any): Unit = () + def setBoolean(value: Boolean): Unit = set(value) + def setByte(value: Byte): Unit = set(value) + def setShort(value: Short): Unit = set(value) + def setInt(value: Int): Unit = set(value) + def setLong(value: Long): Unit = set(value) + def setDouble(value: Double): Unit = set(value) + def setFloat(value: Float): Unit = set(value) + } + + final class RowUpdater(row: InternalRow, i: Int) extends OrcDataUpdater { + override def set(value: Any): Unit = row(i) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(i, value) + override def setByte(value: Byte): Unit = row.setByte(i, value) + override def setShort(value: Short): Unit = row.setShort(i, value) + override def setInt(value: Int): Unit = row.setInt(i, value) + override def setLong(value: Long): Unit = row.setLong(i, value) + override def setDouble(value: Double): Unit = row.setDouble(i, value) + override def setFloat(value: Float): Unit = row.setFloat(i, value) + } + + final class ArrayDataUpdater(updater: OrcDataUpdater) extends OrcDataUpdater { + private val currentArray: ArrayBuffer[Any] = ArrayBuffer.empty[Any] + + override def set(value: Any): Unit = currentArray += value + + def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + } + + final class MapDataUpdater( + keyType: DataType, + valueType: DataType, + updater: OrcDataUpdater) + extends OrcDataUpdater { + + private val currentKeys: ArrayBuffer[Any] = ArrayBuffer.empty[Any] + private val currentValues: ArrayBuffer[Any] = ArrayBuffer.empty[Any] + + private val keyConverter = newConverter(keyType, new OrcDataUpdater { + override def set(value: Any): Unit = currentKeys += value + }) + private val valueConverter = newConverter(valueType, new OrcDataUpdater { + override def set(value: Any): Unit = currentValues += value + }) + + override def set(value: Any): Unit = { + value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + .entrySet().asScala.foreach { entry => + + assert(entry != null) + keyConverter.set(entry.getKey) + assert(valueConverter != null) + if (entry.getValue == null) { + currentValues += null + } else { + valueConverter.set(entry.getValue) } - ArrayBasedMapData(map.asScala) } + } - case udt: UserDefinedType[_] => - 
withNullSafe { o => getValueUnwrapper(udt.sqlType)(o) } - - case _ => - throw new UnsupportedOperationException(s"$dataType is not supported yet.") + def end(): Unit = updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b595c098441f..91f69515e5f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -59,10 +59,8 @@ private[sql] object OrcFileFormat { } } -class DefaultSource extends OrcFileFormat - /** - * New ORC File Format based on Apache ORC 1.4.1 and above. + * New ORC File Format based on Apache ORC. */ class OrcFileFormat extends FileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4cd72bf3c6d8..cec256cc1b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -25,7 +25,34 @@ import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ /** - * Utility functions to convert Spark data source filters to ORC filters. + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. 
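A toy rendering of the double-checking pattern the comment above describes: the children of an And/Or/Not are first trial-converted against throwaway builders, and only when every child is proven convertible is the real builder mutated. The MiniBuilder below is a stand-in for ORC's SearchArgument.Builder, not the real API, and the Filter cases are simplified placeholders.

// Stand-in sketch: a builder whose start/end calls cannot be undone, and a
// converter that trial-runs children on scratch builders before touching the
// real one. This mirrors the pattern described above; it is not the ORC API.
object TrialConversionSketch {
  sealed trait Filter
  case class EqualTo(col: String, value: Any) extends Filter
  case class And(left: Filter, right: Filter) extends Filter
  case class Unsupported(reason: String) extends Filter

  final class MiniBuilder {
    private val ops = scala.collection.mutable.ArrayBuffer.empty[String]
    def startAnd(): MiniBuilder = { ops += "AND("; this }
    def end(): MiniBuilder = { ops += ")"; this }
    def equalTo(col: String, value: Any): MiniBuilder = { ops += s"$col = $value"; this }
    def build(): String = ops.mkString(" ")
  }

  // Returns Some(builder) only when the whole filter is convertible; a failed
  // child leaves a scratch builder inconsistent, but never the real one.
  def convert(filter: Filter, builder: MiniBuilder): Option[MiniBuilder] = filter match {
    case EqualTo(col, value) =>
      Some(builder.equalTo(col, value))
    case And(left, right) =>
      // Trial conversion with fresh builders first.
      val bothConvertible =
        convert(left, new MiniBuilder).isDefined && convert(right, new MiniBuilder).isDefined
      if (bothConvertible) {
        builder.startAnd()
        convert(left, builder)
        convert(right, builder)
        Some(builder.end())
      } else {
        None
      }
    case Unsupported(_) => None
  }

  def main(args: Array[String]): Unit = {
    val good = And(EqualTo("a", 1), EqualTo("b", 2))
    println(convert(good, new MiniBuilder).map(_.build()))                       // Some(AND( a = 1 b = 2 ))
    println(convert(And(EqualTo("a", 1), Unsupported("udf")), new MiniBuilder))  // None
  }
}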
*/ private[orc] object OrcFilters { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 8b12d8c0801e..bd30b7c31add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -33,7 +33,7 @@ private[orc] class OrcSerializer(dataSchema: StructType) { private[this] lazy val orcStruct: OrcStruct = createOrcValue(dataSchema).asInstanceOf[OrcStruct] - private[this] lazy val length = dataSchema.length + private[this] val length = dataSchema.length private[this] val writers = dataSchema.map(_.dataType).map(makeWriter).toArray From e13dfa3600b1a1be3d659eb79ceb73b4078067f8 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 28 Nov 2017 19:15:55 -0800 Subject: [PATCH 14/20] fix --- .../datasources/orc/OrcDeserializer.scala | 14 ++++++-------- .../execution/datasources/orc/OrcFileFormat.scala | 7 +++---- .../sql/execution/datasources/orc/OrcUtils.scala | 10 +++++----- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 6a051684311c..99967c5911e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -166,11 +166,7 @@ private[orc] class OrcDeserializer( val orcStruct = value.asInstanceOf[OrcStruct] val mutableRow = new SpecificInternalRow(st) val fieldConverters: Array[Converter] = st.zipWithIndex.map { case (f, ordinal) => - if (missingColumnNames.contains(f.name)) { - null - } else { - newConverter(f.dataType, new RowUpdater(mutableRow, ordinal)) - } + newConverter(f.dataType, new RowUpdater(mutableRow, ordinal)) }.toArray var i = 0 @@ -239,7 +235,9 @@ private[orc] class OrcDeserializer( trait OrcDataUpdater { def setNullAt(): Unit = () + def set(value: Any): Unit = () + def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) def setShort(value: Short): Unit = set(value) @@ -250,7 +248,10 @@ private[orc] class OrcDeserializer( } final class RowUpdater(row: InternalRow, i: Int) extends OrcDataUpdater { + override def setNullAt(): Unit = row.setNullAt(i) + override def set(value: Any): Unit = row(i) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(i, value) override def setByte(value: Byte): Unit = row.setByte(i, value) override def setShort(value: Short): Unit = row.setShort(i, value) @@ -287,10 +288,7 @@ private[orc] class OrcDeserializer( override def set(value: Any): Unit = { value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] .entrySet().asScala.foreach { entry => - - assert(entry != null) keyConverter.set(entry.getKey) - assert(valueConverter != null) if (entry.getValue == null) { currentValues += null } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 91f69515e5f9..05b84e0b8cdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -69,7 +69,7 @@ class OrcFileFormat override def shortName(): String = "orc" - override def toString: String = "ORC_1.4" + override def toString: String = "ORC" override def hashCode(): Int = getClass.hashCode() @@ -147,12 +147,11 @@ class OrcFileFormat // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. // In this case, `getMissingColumnNames` returns `None` and we return an empty iterator. - val maybeMissingColumnNames = OrcUtils.getMissingColumnNames( + val (isEmptyFile, missingColumnNames) = OrcUtils.getMissingColumnNames( isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) - if (maybeMissingColumnNames.isEmpty) { + if (isEmptyFile) { Iterator.empty } else { - val missingColumnNames = maybeMissingColumnNames.get val columns = requiredSchema .filter(f => !missingColumnNames.contains(f.name)) .map(f => dataSchema.fieldIndex(f.name)).mkString(",") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 5c24cda5ac33..abcd321feee2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -85,23 +85,23 @@ object OrcUtils extends Logging { } /** - * Return missing column names in a give ORC file or `None`. + * Return a pair of isEmptyFile and missing column names in a give ORC file. * Some old empty ORC files always have an empty schema stored in their footer. (SPARK-8501) - * In that case, `None` is returned and OrcFileFormat will handle as empty iterators. + * In that case, isEmptyFile is `true` and missing column names is `None`. */ private[orc] def getMissingColumnNames( isCaseSensitive: Boolean, dataSchema: StructType, partitionSchema: StructType, file: Path, - conf: Configuration): Option[Seq[String]] = { + conf: Configuration): (Boolean, Seq[String]) = { val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val reader = OrcFile.createReader(file, readerOptions) val schema = reader.getSchema if (schema.getFieldNames.size == 0) { - None + (true, Seq.empty[String]) } else { val orcSchema = if (schema.getFieldNames.asScala.forall(_.startsWith("_col"))) { logInfo("Recover ORC schema with data schema") @@ -122,7 +122,7 @@ object OrcUtils extends Logging { } } } - Some(missingColumnNames) + (false, missingColumnNames) } } } From fdab6a7eb3acc63f78889fd31f8f078fca66aa0f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 28 Nov 2017 19:19:09 -0800 Subject: [PATCH 15/20] remove outdate comment. 
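On the read path shown above, once the missing columns of a file are known, the reader derives the ORC include-columns setting by dropping required columns that are absent from the file and mapping the rest to their ordinals in the full data schema. A minimal sketch of that computation, using plain Seq[String] schemas instead of StructType:

// Illustrative sketch of how the reader derives the ORC "include columns"
// setting: required columns that are missing from the file are dropped, the
// rest are mapped to their index in the full data schema and joined by ",".
object IncludeColumnsSketch {
  def includeColumns(
      dataSchema: Seq[String],
      requiredSchema: Seq[String],
      missingColumnNames: Seq[String]): String = {
    requiredSchema
      .filterNot(name => missingColumnNames.contains(name))
      .map(name => dataSchema.indexOf(name))
      .mkString(",")
  }

  def main(args: Array[String]): Unit = {
    val dataSchema = Seq("id", "name", "added_later")
    // "added_later" exists in the table schema but not in this older file.
    println(includeColumns(dataSchema, Seq("name", "added_later"), Seq("added_later")))  // 1
  }
}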
--- .../spark/sql/execution/datasources/orc/OrcFileFormat.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 05b84e0b8cdb..91e6cd88cc70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -146,7 +146,6 @@ class OrcFileFormat val conf = broadcastedConf.value.value // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. - // In this case, `getMissingColumnNames` returns `None` and we return an empty iterator. val (isEmptyFile, missingColumnNames) = OrcUtils.getMissingColumnNames( isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) if (isEmptyFile) { From daef4bac21bf7276578b79a1e05c83a5a407cc0e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 1 Dec 2017 23:07:08 +0800 Subject: [PATCH 16/20] refactor --- .../datasources/orc/OrcDeserializer.scala | 417 ++++++++---------- .../datasources/orc/OrcFileFormat.scala | 19 +- .../datasources/orc/OrcSerializer.scala | 306 +++++++------ .../execution/datasources/orc/OrcUtils.scala | 62 +-- .../spark/sql/hive/HiveInspectors.scala | 5 +- 5 files changed, 381 insertions(+), 428 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 99967c5911e9..4ecc54bd2fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -17,286 +17,227 @@ package org.apache.spark.sql.execution.datasources.orc -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - import org.apache.hadoop.io._ import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[orc] class OrcDeserializer( +/** + * A deserializer to deserialize ORC structs to Spark rows. + */ +class OrcDeserializer( dataSchema: StructType, requiredSchema: StructType, - missingColumnNames: Seq[String]) { - - private[this] val currentRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - - private[this] val length = requiredSchema.length + requestedColIds: Array[Int]) { + + private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + + private val fieldWriters: Array[WritableComparable[_] => Unit] = { + requiredSchema.zipWithIndex + // The value of missing columns are always null, do not need writers. 
+ .filterNot { case (_, index) => requestedColIds(index) == -1 } + .map { case (f, index) => + val writer = newWriter(f.dataType, new RowUpdater(resultRow)) + (value: WritableComparable[_]) => writer(index, value) + }.toArray + } - private[this] val fieldConverters: Array[Converter] = requiredSchema.zipWithIndex.map { - case (f, ordinal) => - if (missingColumnNames.contains(f.name)) { - null - } else { - newConverter(f.dataType, new RowUpdater(currentRow, ordinal)) - } - }.toArray + private val validColIds = requestedColIds.filterNot(_ == -1) def deserialize(orcStruct: OrcStruct): InternalRow = { - val names = orcStruct.getSchema.getFieldNames - val fieldRefs = requiredSchema.map { f => - val name = f.name - if (missingColumnNames.contains(name)) { - null - } else { - if (names.contains(name)) { - orcStruct.getFieldValue(name) - } else { - orcStruct.getFieldValue("_col" + dataSchema.fieldIndex(name)) - } - } - }.toArray - var i = 0 - while (i < length) { - val writable = fieldRefs(i) - if (writable == null) { - currentRow.setNullAt(i) + while (i < validColIds.length) { + val value = orcStruct.getFieldValue(validColIds(i)) + if (value == null) { + resultRow.setNullAt(i) } else { - fieldConverters(i).set(writable) + fieldWriters(i)(value) } i += 1 } - currentRow + resultRow } - private[this] def newConverter(dataType: DataType, updater: OrcDataUpdater): Converter = + /** + * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. + */ + private def newWriter( + dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit = dataType match { - case NullType => - new Converter { - override def set(value: Any): Unit = updater.setNullAt() - } - - case BooleanType => - new Converter { - override def set(value: Any): Unit = - updater.setBoolean(value.asInstanceOf[BooleanWritable].get) - } - - case ByteType => - new Converter { - override def set(value: Any): Unit = updater.setByte(value.asInstanceOf[ByteWritable].get) - } - - case ShortType => - new Converter { - override def set(value: Any): Unit = - updater.setShort(value.asInstanceOf[ShortWritable].get) - } - - case IntegerType => - new Converter { - override def set(value: Any): Unit = updater.setInt(value.asInstanceOf[IntWritable].get) - } - - case LongType => - new Converter { - override def set(value: Any): Unit = updater.setLong(value.asInstanceOf[LongWritable].get) - } - - case FloatType => - new Converter { - override def set(value: Any): Unit = - updater.setFloat(value.asInstanceOf[FloatWritable].get) - } - - case DoubleType => - new Converter { - override def set(value: Any): Unit = - updater.setDouble(value.asInstanceOf[DoubleWritable].get) - } - - case StringType => - new Converter { - override def set(value: Any): Unit = - updater.set(UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) - } - - case BinaryType => - new Converter { - override def set(value: Any): Unit = { - val binary = value.asInstanceOf[BytesWritable] - val bytes = new Array[Byte](binary.getLength) - System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) - updater.set(bytes) + case NullType => (ordinal, _) => + updater.setNullAt(ordinal) + + case BooleanType => (ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => (ordinal, value) => + updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => (ordinal, value) => + updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => (ordinal, 
value) => + updater.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => (ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => (ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => (ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case StringType => (ordinal, value) => + updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) + + case BinaryType => (ordinal, value) => + val binary = value.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + updater.set(ordinal, bytes) + + case DateType => (ordinal, value) => + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + + case TimestampType => (ordinal, value) => + updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => (ordinal, value) => + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + updater.set(ordinal, v) + + case st: StructType => (ordinal, value) => + val result = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(result) + val fieldConverters = st.map(_.dataType).map { dt => + newWriter(dt, fieldUpdater) + }.toArray + val orcStruct = value.asInstanceOf[OrcStruct] + + var i = 0 + while (i < st.length) { + val value = orcStruct.getFieldValue(i) + if (value == null) { + result.setNullAt(i) + } else { + fieldConverters(i)(i, value) } + i += 1 } - case DateType => - new Converter { - override def set(value: Any): Unit = - updater.set(DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) - } + updater.set(ordinal, result) - case TimestampType => - new Converter { - override def set(value: Any): Unit = - updater.set(DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) - } - - case DecimalType.Fixed(precision, scale) => - new Converter { - override def set(value: Any): Unit = { - val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) - v.changePrecision(precision, scale) - updater.set(v) - } - } + case ArrayType(elementType, _) => (ordinal, value) => + val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]] + val length = orcArray.size() + val result = createArrayData(elementType, length) + val elementUpdater = new ArrayDataUpdater(result) + val elementConverter = newWriter(elementType, elementUpdater) - case st: StructType => - new Converter { - override def set(value: Any): Unit = { - val orcStruct = value.asInstanceOf[OrcStruct] - val mutableRow = new SpecificInternalRow(st) - val fieldConverters: Array[Converter] = st.zipWithIndex.map { case (f, ordinal) => - newConverter(f.dataType, new RowUpdater(mutableRow, ordinal)) - }.toArray - - var i = 0 - val length = st.fields.length - while (i < length) { - val name = st(i).name - val writable = orcStruct.getFieldValue(name) - if (writable == null) { - mutableRow.setNullAt(i) - } else { - fieldConverters(i).set(writable) - } - i += 1 - } - updater.set(mutableRow) + var i = 0 + while (i < length) { + val value = orcArray.get(i) + if (value == null) { + result.setNullAt(i) + } else { + elementConverter(i, value) 
} - } - - case ArrayType(elementType, _) => - new Converter { - override def set(value: Any): Unit = { - val arrayDataUpdater = new ArrayDataUpdater(updater) - val converter = newConverter(elementType, arrayDataUpdater) - value.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => - if (x == null) { - arrayDataUpdater.set(null) - } else { - converter.set(x) - } - } - arrayDataUpdater.end() + i += 1 + } + + updater.set(ordinal, result) + + case MapType(keyType, valueType, _) => (ordinal, value) => + val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + val length = orcMap.size() + val keyArray = createArrayData(keyType, length) + val keyUpdater = new ArrayDataUpdater(keyArray) + val keyConverter = newWriter(keyType, keyUpdater) + val valueArray = createArrayData(valueType, length) + val valueUpdater = new ArrayDataUpdater(valueArray) + val valueConverter = newWriter(valueType, valueUpdater) + + var i = 0 + val it = orcMap.entrySet().iterator() + while (it.hasNext) { + val entry = it.next() + keyConverter(i, entry.getKey) + val value = entry.getValue + if (value == null) { + valueArray.setNullAt(i) + } else { + valueConverter(i, value) } + i += 1 } - case MapType(keyType, valueType, _) => - new Converter { - override def set(value: Any): Unit = { - val mapDataUpdater = new MapDataUpdater(keyType, valueType, updater) - mapDataUpdater.set(value) - mapDataUpdater.end() - } - } + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) - case udt: UserDefinedType[_] => - new Converter { - override def set(value: Any): Unit = { - val mutableRow = new SpecificInternalRow(new StructType().add("_col1", udt.sqlType)) - val converter = newConverter(udt.sqlType, new RowUpdater(mutableRow, 0)) - converter.set(value) - updater.set(mutableRow.get(0, dataType)) - } - } + case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") } - - // -------------------------------------------------------------------------- - // Converter and Updaters - // -------------------------------------------------------------------------- - - trait Converter { - def set(value: Any): Unit + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) } - trait OrcDataUpdater { - def setNullAt(): Unit = () - - def set(value: Any): Unit = () - - def setBoolean(value: Boolean): Unit = set(value) - def setByte(value: Byte): Unit = set(value) - def setShort(value: Short): Unit = set(value) - def setInt(value: Int): Unit = set(value) - def setLong(value: Long): Unit = set(value) - def setDouble(value: Double): Unit = set(value) - def setFloat(value: Float): Unit = set(value) + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) } - final class RowUpdater(row: InternalRow, i: Int) extends OrcDataUpdater { - override def setNullAt(): Unit = row.setNullAt(i) - - override def set(value: Any): Unit = row(i) = value - - override def setBoolean(value: Boolean): Unit = row.setBoolean(i, value) - override def setByte(value: Byte): Unit = row.setByte(i, value) - override def setShort(value: Short): Unit = row.setShort(i, value) - override def setInt(value: Int): Unit = row.setInt(i, value) - override def setLong(value: Long): Unit = row.setLong(i, value) - override def setDouble(value: Double): Unit = row.setDouble(i, value) - override def setFloat(value: Float): Unit = row.setFloat(i, value) + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) } - final class ArrayDataUpdater(updater: OrcDataUpdater) extends OrcDataUpdater { - private val currentArray: ArrayBuffer[Any] = ArrayBuffer.empty[Any] - - override def set(value: Any): Unit = currentArray += value - - def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) - } - - final class MapDataUpdater( - keyType: DataType, - valueType: DataType, - updater: OrcDataUpdater) - extends OrcDataUpdater { - - private val currentKeys: ArrayBuffer[Any] = ArrayBuffer.empty[Any] - private val currentValues: ArrayBuffer[Any] = ArrayBuffer.empty[Any] - - private val keyConverter = newConverter(keyType, new OrcDataUpdater { - override def set(value: Any): Unit = currentKeys += value - }) - private val valueConverter = newConverter(valueType, new OrcDataUpdater { - override def set(value: Any): Unit = currentValues += value - }) - - override def set(value: Any): Unit = { - value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - .entrySet().asScala.foreach { entry => - keyConverter.set(entry.getKey) - if (entry.getValue == null) { - currentValues += null - } else { - valueConverter.set(entry.getValue) - } - } - } - - def end(): Unit = updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray)) + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = 
array.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 91e6cd88cc70..75c42213db3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -145,16 +145,17 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value - // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. - val (isEmptyFile, missingColumnNames) = OrcUtils.getMissingColumnNames( - isCaseSensitive, dataSchema, partitionSchema, new Path(new URI(file.filePath)), conf) - if (isEmptyFile) { + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + + if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty } else { - val columns = requiredSchema - .filter(f => !missingColumnNames.contains(f.name)) - .map(f => dataSchema.fieldIndex(f.name)).mkString(",") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, columns) + val requestedColIds = requestedColIdsOrEmptyFile.get + assert(requestedColIds.length == requiredSchema.length, + "[BUG] requested column IDs do not match required schema") + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + requestedColIds.filter(_ != -1).sorted.mkString(",")) val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) @@ -167,7 +168,7 @@ class OrcFileFormat Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) val unsafeProjection = UnsafeProjection.create(requiredSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, missingColumnNames) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) iter.map(value => unsafeProjection(deserializer.deserialize(value))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index bd30b7c31add..7d8ded2b57bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -18,188 +18,218 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.io._ +import org.apache.orc.TypeDescription import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} import org.apache.orc.storage.common.`type`.HiveDecimal import 
org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.orc.OrcUtils.{getTypeDescription, withNullSafe} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -private[orc] class OrcSerializer(dataSchema: StructType) { - - private[this] lazy val orcStruct: OrcStruct = createOrcValue(dataSchema).asInstanceOf[OrcStruct] - - private[this] val length = dataSchema.length +/** + * A serializer to serialize Spark rows to ORC structs. + */ +class OrcSerializer(dataSchema: StructType) { - private[this] val writers = dataSchema.map(_.dataType).map(makeWriter).toArray + private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct] + private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray def serialize(row: InternalRow): OrcStruct = { var i = 0 - while (i < length) { + while (i < converters.length) { if (row.isNullAt(i)) { - orcStruct.setFieldValue(i, null) + result.setFieldValue(i, null) } else { - writers(i)(row, i) + result.setFieldValue(i, converters(i)(row, i)) } i += 1 } - orcStruct + result } - private[this] def makeWriter(dataType: DataType): (SpecializedGetters, Int) => Unit = { - dataType match { - case BooleanType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new BooleanWritable(row.getBoolean(ordinal))) - - case ByteType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new ByteWritable(row.getByte(ordinal))) - - case ShortType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new ShortWritable(row.getShort(ordinal))) - - case IntegerType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new IntWritable(row.getInt(ordinal))) - - case LongType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new LongWritable(row.getLong(ordinal))) - - case FloatType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new FloatWritable(row.getFloat(ordinal))) - - case DoubleType => - (row: SpecializedGetters, ordinal: Int) => - orcStruct.setFieldValue(ordinal, new DoubleWritable(row.getDouble(ordinal))) - - case _ => - val wrapper = getWritableWrapper(dataType) - (row: SpecializedGetters, ordinal: Int) => { - val value = wrapper(row.get(ordinal, dataType)).asInstanceOf[WritableComparable[_]] - orcStruct.setFieldValue(ordinal, value) - } - } - } + private type Converter = (SpecializedGetters, Int) => WritableComparable[_] /** - * Return a Orc value object for the given Spark schema. + * Creates a converter to convert Catalyst data at the given ordinal to ORC values. */ - private[this] def createOrcValue(dataType: DataType) = - OrcStruct.createValue(getTypeDescription(dataType)) + private def newConverter( + dataType: DataType, + reuseObj: Boolean = true): Converter = dataType match { + case NullType => (getter, ordinal) => null + + case BooleanType => + if (reuseObj) { + val result = new BooleanWritable() + (getter, ordinal) => + result.set(getter.getBoolean(ordinal)) + result + } else { + (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal)) + } - /** - * Convert Apache Spark InternalRow to Apache ORC OrcStruct. 
- */ - private[this] def convertInternalRowToOrcStruct(row: InternalRow, schema: StructType) = { - val wrappers = schema.map(_.dataType).map(getWritableWrapper).toArray - val orcStruct = createOrcValue(schema).asInstanceOf[OrcStruct] + case ByteType => + if (reuseObj) { + val result = new ByteWritable() + (getter, ordinal) => + result.set(getter.getByte(ordinal)) + result + } else { + (getter, ordinal) => new ByteWritable(getter.getByte(ordinal)) + } - var i = 0 - val length = schema.length - while (i < length) { - val fieldType = schema(i).dataType - if (row.isNullAt(i)) { - orcStruct.setFieldValue(i, null) + case ShortType => + if (reuseObj) { + val result = new ShortWritable() + (getter, ordinal) => + result.set(getter.getShort(ordinal)) + result } else { - val field = row.get(i, fieldType) - val fieldValue = wrappers(i)(field).asInstanceOf[WritableComparable[_]] - orcStruct.setFieldValue(i, fieldValue) + (getter, ordinal) => new ShortWritable(getter.getShort(ordinal)) } - i += 1 - } - orcStruct - } - /** - * Builds a WritableComparable-return function ahead of time according to DataType - * to avoid pattern matching and branching costs per row. - */ - private[this] def getWritableWrapper(dataType: DataType): Any => Any = dataType match { - case NullType => _ => null + case IntegerType => + if (reuseObj) { + val result = new IntWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new IntWritable(getter.getInt(ordinal)) + } - case BooleanType => withNullSafe(o => new BooleanWritable(o.asInstanceOf[Boolean])) - case ByteType => withNullSafe(o => new ByteWritable(o.asInstanceOf[Byte])) - case ShortType => withNullSafe(o => new ShortWritable(o.asInstanceOf[Short])) - case IntegerType => withNullSafe(o => new IntWritable(o.asInstanceOf[Int])) - case LongType => withNullSafe(o => new LongWritable(o.asInstanceOf[Long])) + case LongType => + if (reuseObj) { + val result = new LongWritable() + (getter, ordinal) => + result.set(getter.getLong(ordinal)) + result + } else { + (getter, ordinal) => new LongWritable(getter.getLong(ordinal)) + } + + case FloatType => + if (reuseObj) { + val result = new FloatWritable() + (getter, ordinal) => + result.set(getter.getFloat(ordinal)) + result + } else { + (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal)) + } - case FloatType => withNullSafe(o => new FloatWritable(o.asInstanceOf[Float])) - case DoubleType => withNullSafe(o => new DoubleWritable(o.asInstanceOf[Double])) + case DoubleType => + if (reuseObj) { + val result = new DoubleWritable() + (getter, ordinal) => + result.set(getter.getDouble(ordinal)) + result + } else { + (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal)) + } - case StringType => withNullSafe(o => new Text(o.asInstanceOf[UTF8String].getBytes)) - case BinaryType => withNullSafe(o => new BytesWritable(o.asInstanceOf[Array[Byte]])) + // Don't reuse the result object for string and binary as it would cause extra data copy. 
+ case StringType => (getter, ordinal) => + new Text(getter.getUTF8String(ordinal).getBytes) + + case BinaryType => (getter, ordinal) => + new BytesWritable(getter.getBinary(ordinal)) case DateType => - withNullSafe(o => new DateWritable(DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))) - case TimestampType => - withNullSafe { o => - val us = o.asInstanceOf[Long] - var seconds = us / DateTimeUtils.MICROS_PER_SECOND - var micros = us % DateTimeUtils.MICROS_PER_SECOND - if (micros < 0) { - micros += DateTimeUtils.MICROS_PER_SECOND - seconds -= 1 - } - val t = new OrcTimestamp(seconds * 1000) - t.setNanos(micros.toInt * 1000) - t + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) } - case _: DecimalType => - withNullSafe { o => - new HiveDecimalWritable(HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) - } + // Already expensive, reusing object or not doesn't matter. + case TimestampType => + (getter, ordinal) => + val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) + val result = new OrcTimestamp(ts.getTime) + result.setNanos(ts.getNanos) + result + + // Already expensive, reusing object or not doesn't matter. + case DecimalType.Fixed(precision, scale) => + (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) case st: StructType => - withNullSafe(o => convertInternalRowToOrcStruct(o.asInstanceOf[InternalRow], st)) - - case ArrayType(et, _) => - withNullSafe { o => - val data = o.asInstanceOf[ArrayData] - val list = createOrcValue(dataType) - for (i <- 0 until data.numElements()) { - val d = data.get(i, et) - val v = getWritableWrapper(et)(d).asInstanceOf[WritableComparable[_]] - list.asInstanceOf[OrcList[WritableComparable[_]]].add(v) + val result = createOrcValue(st).asInstanceOf[OrcStruct] + val fieldConverters = st.map(_.dataType).map(newConverter(_)) + val numFields = st.length + (getter, ordinal) => + val struct = getter.getStruct(ordinal, numFields) + var i = 0 + while (i < numFields) { + if (struct.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, fieldConverters(i)(struct, i)) + } + i += 1 } - list - } + result + + case ArrayType(elementType, _) => + val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val elementConverter = newConverter(elementType, reuseObj = false) + (getter, ordinal) => + result.clear() + val array = getter.getArray(ordinal) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(array, i)) + } + i += 1 + } + result case MapType(keyType, valueType, _) => - withNullSafe { o => - val keyWrapper = getWritableWrapper(keyType) - val valueWrapper = getWritableWrapper(valueType) - val data = o.asInstanceOf[MapData] - val map = createOrcValue(dataType) - .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - data.foreach(keyType, valueType, { case (k, v) => - map.put( - keyWrapper(k).asInstanceOf[WritableComparable[_]], - valueWrapper(v).asInstanceOf[WritableComparable[_]]) - }) - map - } + val result = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. 
+ val keyConverter = newConverter(keyType, reuseObj = false) + val valueConverter = newConverter(valueType, reuseObj = false) + (getter, ordinal) => + result.clear() + val map = getter.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + val key = keyConverter(keyArray, i) + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result - case udt: UserDefinedType[_] => - withNullSafe { o => - val udtRow = new SpecificInternalRow(Seq(udt.sqlType)) - udtRow(0) = o - convertInternalRowToOrcStruct( - udtRow, - StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0) - } + case udt: UserDefinedType[_] => newConverter(udt.sqlType) case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") } + + /** + * Return a Orc value object for the given Spark schema. + */ + private def createOrcValue(dataType: DataType) = { + OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index abcd321feee2..f16db57e5e44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.io.IOException - import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -42,10 +39,6 @@ object OrcUtils extends Logging { "ZLIB" -> ".zlib", "LZO" -> ".lzo") - def withNullSafe(f: Any => Any): Any => Any = { - input => if (input == null) null else f(input) - } - def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) @@ -57,7 +50,7 @@ object OrcUtils extends Logging { paths } - private[orc] def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val reader = OrcFile.createReader(file, readerOptions) @@ -69,7 +62,7 @@ object OrcUtils extends Logging { } } - private[orc] def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) + def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. @@ -79,50 +72,35 @@ object OrcUtils extends Logging { } } - private[orc] def getTypeDescription(dataType: DataType) = dataType match { - case st: StructType => TypeDescription.fromString(st.catalogString) - case _ => TypeDescription.fromString(dataType.catalogString) - } - /** - * Return a pair of isEmptyFile and missing column names in a give ORC file. - * Some old empty ORC files always have an empty schema stored in their footer. (SPARK-8501) - * In that case, isEmptyFile is `true` and missing column names is `None`. + * Returns the requested column ids from the given ORC file. Column id can be -1, which means the + * requested column doesn't exist in the ORC file. 
Returns None if the given ORC file is empty. */ - private[orc] def getMissingColumnNames( + def requestedColumnIds( isCaseSensitive: Boolean, dataSchema: StructType, - partitionSchema: StructType, + requiredSchema: StructType, file: Path, - conf: Configuration): (Boolean, Seq[String]) = { - val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution + conf: Configuration): Option[Array[Int]] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - (true, Seq.empty[String]) + val orcFieldNames = reader.getSchema.getFieldNames.asScala + if (orcFieldNames.isEmpty) { + // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. + None } else { - val orcSchema = if (schema.getFieldNames.asScala.forall(_.startsWith("_col"))) { - logInfo("Recover ORC schema with data schema") - var schemaString = schema.toString - dataSchema.zipWithIndex.foreach { case (field: StructField, index: Int) => - schemaString = schemaString.replace(s"_col$index:", s"${field.name}:") - } - TypeDescription.fromString(schemaString) + if (orcFieldNames.forall(_.startsWith("_col"))) { + // This is a ORC file written by Hive, no field names in the physical schema, assume the + // physical schema maps to the data scheme by index. + assert(orcFieldNames.length == dataSchema.length, "The given data schema " + + s"${dataSchema.simpleString} doesn't match the actual ORC physical schema " + + orcFieldNames.mkString(", ")) + Some(requiredSchema.fieldNames.map { name => dataSchema.fieldIndex(name) }) } else { - schema - } - - val missingColumnNames = new ArrayBuffer[String] - if (dataSchema.length > orcSchema.getFieldNames.size) { - dataSchema.filter(x => partitionSchema.getFieldIndex(x.name).isEmpty).foreach { f => - if (!orcSchema.getFieldNames.asScala.exists(resolver(_, f.name))) { - missingColumnNames += f.name - } - } + val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution + Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) }) } - (false, missingColumnNames) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7e6f89932b74..4dec2f71b8a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.orc.OrcUtils.withNullSafe import org.apache.spark.sql.types import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -257,6 +256,10 @@ private[hive] trait HiveInspectors { case _ => false } + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + /** * Wraps with Hive types based on object inspector. 
*/ From f143e1744ecef2203757e57f3aa882832ee4c37a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 2 Dec 2017 01:36:50 +0800 Subject: [PATCH 17/20] address comment --- .../sql/execution/datasources/orc/OrcUtils.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index f16db57e5e44..b03ee06d04a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -93,10 +93,17 @@ object OrcUtils extends Logging { if (orcFieldNames.forall(_.startsWith("_col"))) { // This is a ORC file written by Hive, no field names in the physical schema, assume the // physical schema maps to the data scheme by index. - assert(orcFieldNames.length == dataSchema.length, "The given data schema " + - s"${dataSchema.simpleString} doesn't match the actual ORC physical schema " + - orcFieldNames.mkString(", ")) - Some(requiredSchema.fieldNames.map { name => dataSchema.fieldIndex(name) }) + assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + + s"${dataSchema.simpleString} has less fields than the actual ORC physical schema, " + + "no idea which columns were dropped, fail to read.") + Some(requiredSchema.fieldNames.map { name => + val index = dataSchema.fieldIndex(name) + if (index < orcFieldNames.length) { + index + } else { + -1 + } + }) } else { val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) }) From 8a34731ebfe946ec48bf1353f675cc3b7c987ae2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 1 Dec 2017 09:39:38 -0800 Subject: [PATCH 18/20] Revert the change on HiveInspectors.scala. --- .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7e6f89932b74..4dec2f71b8a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.orc.OrcUtils.withNullSafe import org.apache.spark.sql.types import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -257,6 +256,10 @@ private[hive] trait HiveInspectors { case _ => false } + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + /** * Wraps with Hive types based on object inspector. 
*/ From eae50b3e776bb64d2927a8c8edb43562983e25fb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 2 Dec 2017 02:02:41 +0800 Subject: [PATCH 19/20] bug fix --- .../datasources/orc/OrcSerializer.scala | 103 ++++++++---------- 1 file changed, 48 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 7d8ded2b57bb..899af0750cad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -148,77 +148,70 @@ class OrcSerializer(dataSchema: StructType) { (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) } - // Already expensive, reusing object or not doesn't matter. - case TimestampType => - (getter, ordinal) => - val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) - val result = new OrcTimestamp(ts.getTime) - result.setNanos(ts.getNanos) - result - - // Already expensive, reusing object or not doesn't matter. - case DecimalType.Fixed(precision, scale) => - (getter, ordinal) => - val d = getter.getDecimal(ordinal, precision, scale) - new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) - - case st: StructType => + // The following cases are already expensive, reusing object or not doesn't matter. + + case TimestampType => (getter, ordinal) => + val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) + val result = new OrcTimestamp(ts.getTime) + result.setNanos(ts.getNanos) + result + + case DecimalType.Fixed(precision, scale) => (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + + case st: StructType => (getter, ordinal) => val result = createOrcValue(st).asInstanceOf[OrcStruct] val fieldConverters = st.map(_.dataType).map(newConverter(_)) val numFields = st.length - (getter, ordinal) => - val struct = getter.getStruct(ordinal, numFields) - var i = 0 - while (i < numFields) { - if (struct.isNullAt(i)) { - result.setFieldValue(i, null) - } else { - result.setFieldValue(i, fieldConverters(i)(struct, i)) - } - i += 1 + val struct = getter.getStruct(ordinal, numFields) + var i = 0 + while (i < numFields) { + if (struct.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, fieldConverters(i)(struct, i)) } - result + i += 1 + } + result - case ArrayType(elementType, _) => + case ArrayType(elementType, _) => (getter, ordinal) => val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] // Need to put all converted values to a list, can't reuse object. 
val elementConverter = newConverter(elementType, reuseObj = false) - (getter, ordinal) => - result.clear() - val array = getter.getArray(ordinal) - var i = 0 - while (i < array.numElements()) { - if (array.isNullAt(i)) { - result.add(null) - } else { - result.add(elementConverter(array, i)) - } - i += 1 + val array = getter.getArray(ordinal) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(array, i)) } - result + i += 1 + } + result - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _) => (getter, ordinal) => val result = createOrcValue(dataType) .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] // Need to put all converted values to a list, can't reuse object. val keyConverter = newConverter(keyType, reuseObj = false) val valueConverter = newConverter(valueType, reuseObj = false) - (getter, ordinal) => - result.clear() - val map = getter.getMap(ordinal) - val keyArray = map.keyArray() - val valueArray = map.valueArray() - var i = 0 - while (i < map.numElements()) { - val key = keyConverter(keyArray, i) - if (valueArray.isNullAt(i)) { - result.put(key, null) - } else { - result.put(key, valueConverter(valueArray, i)) - } - i += 1 + val map = getter.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + val key = keyConverter(keyArray, i) + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) } - result + i += 1 + } + result case udt: UserDefinedType[_] => newConverter(udt.sqlType) From 71be008cc887a1fa53a20ada0417a14e03a4ae89 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Dec 2017 18:14:49 -0800 Subject: [PATCH 20/20] Restore old ORC implementation. 
--- ...pache.spark.sql.sources.DataSourceRegister | 1 - .../execution/datasources/DataSource.scala | 3 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 ++++++++++++++++-- .../sql/sources/DDLSourceLoadSuite.scala | 7 +++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/sql/hive/HiveStrategies.scala | 2 +- .../spark/sql/hive/orc/OrcFileOperator.scala | 20 +++++++++++-------- .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 - .../spark/sql/hive/orc/OrcQuerySuite.scala | 17 ++++------------ .../spark/sql/hive/orc/OrcSourceSuite.scala | 16 +++------------ 10 files changed, 46 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae564..0c5f3f22e31e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,7 +1,6 @@ org.apache.spark.sql.execution.datasources.csv.CSVFileFormat org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider org.apache.spark.sql.execution.datasources.json.JsonFileFormat -org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fdf113d36b3b..b43d282bd434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ @@ -537,7 +536,7 @@ object DataSource extends Logging { val parquet = classOf[ParquetFileFormat].getCanonicalName val csv = classOf[CSVFileFormat].getCanonicalName val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" - val orc = classOf[OrcFileFormat].getCanonicalName + val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" Map( "org.apache.spark.sql.jdbc" -> jdbc, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9f4469a09ddb..31d9b909ad46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1661,6 +1661,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("Path does not exist")) + e = intercept[AnalysisException] { + sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") + } + assert(e.message.contains("The ORC data source must be used with Hive support enabled")) + e = intercept[AnalysisException] { 
sql(s"select id from `com.databricks.spark.avro`.`file_path`") } @@ -2753,8 +2758,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - // Only New OrcFileFormat supports this. - Seq("orc", "parquet").foreach { format => + // Only New OrcFileFormat supports this + Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, + "parquet").foreach { format => test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { withTempPath { file => val path = file.getCanonicalPath @@ -2767,4 +2773,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21791 ORC should support column names with dot") { + val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path) + assert(spark.read.format(orc).load(path).collect().length == 2) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index f22d843bfabd..3ce6ae3c5292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -53,6 +53,13 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) } + + test("should fail to load ORC without Hive Support") { + val e = intercept[AnalysisException] { + spark.read.format("orc").load() + } + assert(e.message.contains("The ORC data source must be used with Hive support enabled")) + } } diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index d73a2e5dbeae..e7d762fbebe7 100644 --- a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1,2 @@ +org.apache.spark.sql.hive.orc.OrcFileFormat org.apache.spark.sql.hive.execution.HiveFileFormat \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 7c41cba33623..ee1f6ee17306 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index aa0be0630dc3..5a3fcd7a759c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -22,10 +22,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.execution.datasources.orc.OrcUtils -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.types.StructType private[hive] object OrcFileOperator extends Logging { @@ -65,7 +64,7 @@ private[hive] object OrcFileOperator extends Logging { hdfsPath.getFileSystem(conf) } - OrcUtils.listOrcFiles(basePath, conf).iterator.map { path => + listOrcFiles(basePath, conf).iterator.map { path => path -> OrcFile.createReader(fs, path) }.collectFirst { case (path, reader) if isWithNonEmptySchema(path, reader) => reader @@ -88,10 +87,15 @@ private[hive] object OrcFileOperator extends Logging { getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) } - def setRequiredColumns( - conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) - val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + // TODO: Check if the paths coming in are already qualified and simplify. 
+ val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDirectory) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + paths } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index c7c1264dca5d..ba0a7605da71 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -23,7 +23,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.catalog.CatalogUtils -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 021c8c495854..1fa9091f967a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.hive.orc -import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll @@ -61,14 +59,6 @@ case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { - private def getFileReader(path: String, extensions: String) = { - import org.apache.orc.OrcFile - val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(extensions)) - assert(maybeOrcFile.isDefined) - val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) - OrcFile.createReader(orcFilePath, OrcFile.readerOptions(new Configuration())) - } - test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) @@ -240,13 +230,14 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("LZO compression options for writing to an ORC file") { + // Following codec is not supported in Hive 1.2.1, ignore it now + ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { withTempPath { file => spark.range(0, 10).write .option("compression", "LZO") .orc(file.getCanonicalPath) val expectedCompressionKind = - getFileReader(file.getAbsolutePath, ".lzo.orc").getCompressionKind + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression assert("LZO" === expectedCompressionKind.name()) } } @@ -608,7 +599,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcFileOperator.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = 
OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 1f8ee0becbcb..2a086be57f51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -226,13 +225,13 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } } -class OrcSourceSuite extends OrcSuite with SQLTestUtils { +class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_source - |USING orc + |USING org.apache.spark.sql.hive.orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -240,7 +239,7 @@ class OrcSourceSuite extends OrcSuite with SQLTestUtils { spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_as_source - |USING orc + |USING org.apache.spark.sql.hive.orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -279,13 +278,4 @@ class OrcSourceSuite extends OrcSuite with SQLTestUtils { )).get.toString } } - - test("SPARK-21791 ORC should support column names with dot") { - import spark.implicits._ - withTempDir { dir => - val path = new File(dir, "orc").getCanonicalPath - Seq(Some(1), None).toDF("col.dots").write.orc(path) - assert(spark.read.orc(path).collect().length == 2) - } - } }
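
Editor's note: the column-id mapping that patches 16 and 17 introduce in OrcUtils.requestedColumnIds is the heart of the new read path, so a standalone sketch of that logic follows for reference. It is a simplified illustration, not code from the patch: schemas are plain name sequences, case-insensitive matching stands in for Spark's resolver, and the physical field names are passed as an argument instead of being read from an ORC footer.

// Minimal sketch of the requestedColumnIds mapping (assumptions noted above).
object RequestedColumnIdsSketch {
  // Returns None for an empty physical schema (SPARK-8501); otherwise one id per
  // required field, where -1 marks a column that is absent from the ORC file.
  def requestedColumnIds(
      orcFieldNames: Seq[String],
      dataSchema: Seq[String],
      requiredSchema: Seq[String]): Option[Array[Int]] = {
    if (orcFieldNames.isEmpty) {
      None
    } else if (orcFieldNames.forall(_.startsWith("_col"))) {
      // ORC file written by Hive: no real field names, so map by position and
      // treat columns beyond the physical schema as missing.
      assert(orcFieldNames.length <= dataSchema.length,
        "data schema has fewer fields than the ORC physical schema")
      Some(requiredSchema.map { name =>
        val index = dataSchema.indexOf(name)
        if (index < orcFieldNames.length) index else -1
      }.toArray)
    } else {
      // Field names are present: match by (here, case-insensitive) name.
      Some(requiredSchema.map { name =>
        orcFieldNames.indexWhere(_.equalsIgnoreCase(name))
      }.toArray)
    }
  }

  def main(args: Array[String]): Unit = {
    // Hive-written file with two physical columns but a three-column table schema.
    println(requestedColumnIds(
      orcFieldNames = Seq("_col0", "_col1"),
      dataSchema = Seq("a", "b", "c"),
      requiredSchema = Seq("c", "a")).map(_.mkString(",")))  // prints Some(-1,0)
  }
}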
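
Likewise, a hedged usage sketch of the new data source as it stands after patch 20: the short name "orc" resolves to the Hive-based implementation again, so the sql/core reader and writer are addressed by their canonical class name, as the SPARK-21791 test above does. The local master and output path below are placeholders, not part of the patch.

// Usage sketch mirroring the SPARK-21791 test, with placeholder path/master.
import org.apache.spark.sql.SparkSession

object NewOrcFormatUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val newOrc = "org.apache.spark.sql.execution.datasources.orc.OrcFileFormat"
    val path = "/tmp/new-orc-example"  // placeholder output directory

    // Column names with dots are accepted by the new format (SPARK-21791).
    Seq(Some(1), None).toDF("col.dots").write.format(newOrc).mode("overwrite").save(path)
    assert(spark.read.format(newOrc).load(path).collect().length == 2)

    spark.stop()
  }
}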