diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 21fc1224eff2..fa72bc88937e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -137,7 +137,6 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
 
     val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes
-    val joinedRow = new JoinedRow()
 
     file => {
       val lines = new HadoopFileLinesReader(file, broadcastedConf.value.value).map(_.toString)
@@ -148,10 +147,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         columnNameOfCorruptRecord,
         parsedOptions)
 
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-      rows.map { row =>
-        appendPartitionColumns(joinedRow(row, file.partitionValues))
-      }
+      FileFormat.appendPartitionValues(rows, fullSchema, file.partitionValues)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index d6b84be26741..3a31241e8609 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -395,14 +395,13 @@ private[sql] class DefaultSource
         iter.asInstanceOf[Iterator[InternalRow]]
       } else {
         val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes
-        val joinedRow = new JoinedRow()
-        val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-
-        // This is a horrible erasure hack... if we type the iterator above, then it actually check
-        // the type in next() and we get a class cast exception. If we make that function return
-        // Object, then we can defer the cast until later!
-        iter.asInstanceOf[Iterator[InternalRow]]
-          .map(d => appendPartitionColumns(joinedRow(d, file.partitionValues)))
+        FileFormat.appendPartitionValues(
+          // This is a horrible erasure hack... if we type the iterator above, then it actually
+          // check the type in next() and we get a class cast exception. If we make that function
+          // return Object, then we can defer the cast until later!
+          iter.asInstanceOf[Iterator[InternalRow]],
+          fullSchema,
+          file.partitionValues)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 1e02354edf4c..b6c511a6fa80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.FileRelation
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.streaming.{Sink, Source}
@@ -484,6 +485,17 @@ trait FileFormat {
   }
 }
 
+private[sql] object FileFormat {
+  def appendPartitionValues(
+      rows: Iterator[InternalRow],
+      output: Seq[Attribute],
+      partitionValues: InternalRow): Iterator[InternalRow] = {
+    val joinedRow = new JoinedRow()
+    val appendPartitionColumns = GenerateUnsafeProjection.generate(output, output)
+    rows.map { row => appendPartitionColumns(joinedRow(row, partitionValues)) }
+  }
+}
+
 /**
  * A collection of data files from a partitioned relation, along with the partition values in the
  * form of an [[InternalRow]].
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 7c4a0a0c0f09..004d4311f4d5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -174,14 +174,10 @@ private[sql] class DefaultSource
         file.filePath, conf, dataSchema, new RecordReaderIterator[OrcStruct](orcRecordReader)
       )
 
-      // Appends partition values
-      val fullOutput = dataSchema.toAttributes ++ partitionSchema.toAttributes
-      val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
-
-      unsafeRowIterator.map { dataRow =>
-        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
-      }
+      FileFormat.appendPartitionValues(
+        unsafeRowIterator,
+        dataSchema.toAttributes ++ partitionSchema.toAttributes,
+        file.partitionValues)
     }
   }
 }
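
Net effect: the `JoinedRow` + `GenerateUnsafeProjection` boilerplate that the JSON, Parquet, and ORC readers each duplicated now lives in the single `private[sql]` helper `FileFormat.appendPartitionValues`. A minimal sketch of the call shape from a per-file reader closure, assuming the package layout at this commit; `AppendPartitionValuesExample`, `exampleReader`, and `parseFile` are hypothetical names, not part of the patch:

```scala
// Sketch only, not part of this patch. FileFormat.appendPartitionValues is private[sql],
// so code like this would have to live under the org.apache.spark.sql package.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.sources.FileFormat
import org.apache.spark.sql.types.StructType

object AppendPartitionValuesExample {
  // Builds a per-file reader that parses the data columns and appends the partition values,
  // mirroring what the JSON/Parquet/ORC buildReader implementations do after this change.
  def exampleReader(
      dataSchema: StructType,
      partitionSchema: StructType,
      parseFile: PartitionedFile => Iterator[InternalRow])
    : PartitionedFile => Iterator[InternalRow] = {
    // Data columns first, then partition columns: the relation's full output order.
    val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes
    file => {
      // One call replaces the per-source JoinedRow / GenerateUnsafeProjection code removed above.
      FileFormat.appendPartitionValues(parseFile(file), fullSchema, file.partitionValues)
    }
  }
}
```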