diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2ec42d3aea16..32bc016ce504 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -41,20 +41,20 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
   with SparkHadoopMapRedUtil
   with Serializable {
 
-  private val now = new Date()
-  private val conf = new SerializableWritable(jobConf)
+  protected val now = new Date()
+  protected val conf = new SerializableWritable(jobConf)
 
-  private var jobID = 0
-  private var splitID = 0
-  private var attemptID = 0
-  private var jID: SerializableWritable[JobID] = null
-  private var taID: SerializableWritable[TaskAttemptID] = null
+  protected var jobID = 0
+  protected var splitID = 0
+  protected var attemptID = 0
+  protected var jID: SerializableWritable[JobID] = null
+  protected var taID: SerializableWritable[TaskAttemptID] = null
 
-  @transient private var writer: RecordWriter[AnyRef,AnyRef] = null
-  @transient private var format: OutputFormat[AnyRef,AnyRef] = null
-  @transient private var committer: OutputCommitter = null
-  @transient private var jobContext: JobContext = null
-  @transient private var taskContext: TaskAttemptContext = null
+  @transient protected var writer: RecordWriter[AnyRef,AnyRef] = null
+  @transient protected var format: OutputFormat[AnyRef,AnyRef] = null
+  @transient protected var committer: OutputCommitter = null
+  @transient protected var jobContext: JobContext = null
+  @transient protected var taskContext: TaskAttemptContext = null
 
   def preSetup() {
     setIDs(0, 0, 0)
@@ -112,9 +112,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
     cmtr.commitJob(getJobContext())
   }
 
-  // ********* Private Functions *********
-
-  private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
+  def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
     if (format == null) {
       format = conf.value.getOutputFormat()
         .asInstanceOf[OutputFormat[AnyRef,AnyRef]]
@@ -122,28 +120,28 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
     format
   }
 
-  private def getOutputCommitter(): OutputCommitter = {
+  def getOutputCommitter(): OutputCommitter = {
     if (committer == null) {
       committer = conf.value.getOutputCommitter
     }
     committer
   }
 
-  private def getJobContext(): JobContext = {
+  def getJobContext(): JobContext = {
     if (jobContext == null) {
       jobContext = newJobContext(conf.value, jID.value)
     }
     jobContext
   }
 
-  private def getTaskContext(): TaskAttemptContext = {
+  def getTaskContext(): TaskAttemptContext = {
     if (taskContext == null) {
       taskContext = newTaskAttemptContext(conf.value, taID.value)
     }
     taskContext
  }
 
-  private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+  def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
     jobID = jobid
     splitID = splitid
     attemptID = attemptid
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index e556c74ffb01..5c205277404e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -330,7 +330,7 @@ private[hive] object HadoopTableReader extends HiveInspectors {
     }
 
     val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
-      soi.getStructFieldRef(attr.name) -> ordinal
+      soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal
     }.unzip
 
     /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/orc.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/orc.scala
new file mode 100644
index 000000000000..75674628f6a3
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/orc.scala
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.IOException
+import java.util.{Locale, Properties}
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.mapred.{JobConf, InputFormat, FileInputFormat}
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.hive.ql.io.orc._
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector}
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.{SparkContext, SparkHadoopWriter, SerializableWritable, Logging}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.hive._
+import org.apache.spark.rdd.{HadoopRDD, RDD}
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+
+
+/**
+ * Allows creation of orc based tables using the syntax
+ * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.orc`.
+ * Currently the only option required is `path`, which should be the location of a collection of,
+ * optionally partitioned, orc files.
+ */
+class DefaultSource
+  extends RelationProvider
+  with SchemaRelationProvider
+  with CreatableRelationProvider {
+
+  private def checkPath(parameters: Map[String, String]): String = {
+    parameters.getOrElse("path", sys.error("'path' must be specified for orc tables."))
+  }
+
+  /** Returns a new base relation with the given parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    OrcRelation(checkPath(parameters), parameters, None)(sqlContext)
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation = {
+    OrcRelation(checkPath(parameters), parameters, Some(schema))(sqlContext)
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val path = checkPath(parameters)
+    val filesystemPath = new Path(path)
+    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+    val doSave = if (fs.exists(filesystemPath)) {
+      mode match {
+        case SaveMode.Append =>
+          sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+        case SaveMode.Overwrite =>
+          fs.delete(filesystemPath, true)
+          true
+        case SaveMode.ErrorIfExists =>
+          sys.error(s"path $path already exists.")
+        case SaveMode.Ignore => false
+      }
+    } else {
+      true
+    }
+
+    val relation = if (doSave) {
+      val createdRelation = createRelation(sqlContext, parameters, data.schema)
+      createdRelation.asInstanceOf[OrcRelation].insert(data, true)
+      createdRelation
+    } else {
+      // If the save mode is Ignore, we will just create the relation based on existing data.
+      createRelation(sqlContext, parameters)
+    }
+
+    relation
+  }
+}
+
+@DeveloperApi
+case class OrcRelation
+    (path: String, parameters: Map[String, String], maybeSchema: Option[StructType] = None)
+    (@transient val sqlContext: SQLContext)
+  extends BaseRelation
+  with CatalystScan
+  with InsertableRelation
+  with SparkHadoopMapRedUtil
+  with HiveInspectors
+  with Logging {
+
+  def sparkContext: SparkContext = sqlContext.sparkContext
+
+  // todo: Should calculate per scan size
+  override def sizeInBytes: Long = {
+    val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration)
+    val fileStatus = fs.getFileStatus(fs.makeQualified(new Path(path)))
+    val leaves = SparkHadoopUtil.get.listLeafStatuses(fs, fileStatus.getPath).filter { f =>
+      !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith("."))
+    }
+    leaves.map(_.getLen).sum
+  }
+
+  private def initialColumnsNamesTypes(schema: StructType) = {
+    val inspector = toInspector(schema).asInstanceOf[StructObjectInspector]
+    val fields = inspector.getAllStructFieldRefs
+    val (columns, columnTypes) = fields.map { f =>
+      f.getFieldName -> f.getFieldObjectInspector.getTypeName
+    }.unzip
+    val columnsNames = columns.mkString(",")
+    val columnsTypes = columnTypes.mkString(":")
+    (columnsNames, columnsTypes)
+  }
+
+  private def orcSchema(
+      path: Path,
+      configuration: Option[Configuration]): StructType = {
+    // get the schema info through ORC Reader
+    val conf = configuration.getOrElse(new Configuration())
+    val fs: FileSystem = path.getFileSystem(conf)
+    val reader = OrcFile.createReader(fs, path)
+    require(reader != null, "metadata reader is null!")
+    if (reader == null) {
+      // return empty seq when saveAsOrcFile
+      return StructType(Seq.empty)
+    }
+    val inspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector]
+    // data types that is inspected by this inspector
+    val schema = inspector.getTypeName
+    HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType]
+  }
+
+  lazy val schema = {
+    val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration)
+    val childrenOfPath = fs.listStatus(new Path(path))
+      .filterNot(_.getPath.getName.startsWith("_"))
+      .filterNot(_.isDir)
+    maybeSchema.getOrElse(orcSchema(
+      childrenOfPath.head.getPath,
+      Some(sparkContext.hadoopConfiguration)))
+  }
+
+  override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = {
+    val sc = sparkContext
+    val conf: Configuration = sc.hadoopConfiguration
+
+    val setInputPathsFunc: Option[JobConf => Unit] =
+      Some((jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path))
+
+    addColumnIds(output, schema.toAttributes, conf)
+    val inputClass =
+      classOf[OrcInputFormat].asInstanceOf[Class[_ <: InputFormat[NullWritable, Writable]]]
+
+    // use SpecificMutableRow to decrease GC garbage
+    val mutableRow = new SpecificMutableRow(output.map(_.dataType))
+    val attrsWithIndex = output.zipWithIndex
+    val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+    val (columnsNames, columnsTypes) = initialColumnsNamesTypes(schema)
+    val rowRdd =
+      new HadoopRDD(
+        sc,
+        confBroadcast,
+        setInputPathsFunc,
+        inputClass,
+        classOf[NullWritable],
+        classOf[Writable],
+        sc.defaultMinPartitions).mapPartitionsWithInputSplit { (split, iter) =>
+
+        val deserializer = {
+          val prop: Properties = new Properties
+          prop.setProperty("columns", columnsNames)
+          prop.setProperty("columns.types", columnsTypes)
+
+          val serde = new OrcSerde
+          serde.initialize(null, prop)
+          serde
+        }
+        HadoopTableReader.fillObject(
+          iter.map(_._2),
+          deserializer,
+          attrsWithIndex,
+          mutableRow,
+          deserializer)
+      }
+    rowRdd
+  }
+
+  /**
+   * add column ids and names
+   * @param output
+   * @param relationOutput
+   * @param conf
+   */
+  private def addColumnIds(
+      output: Seq[Attribute],
+      relationOutput: Seq[Attribute],
+      conf: Configuration) {
+    val names = output.map(_.name)
+    val fieldIdMap = relationOutput.map(_.name.toLowerCase(Locale.ENGLISH)).zipWithIndex.toMap
+    val ids = output.map { att =>
+      val realName = att.name.toLowerCase(Locale.ENGLISH)
+      fieldIdMap.getOrElse(realName, -1)
+    }.filter(_ >= 0).map(_.asInstanceOf[Integer])
+
+    assert(ids.size == output.size, "columns id and name length does not match!")
+    if (ids != null && !ids.isEmpty) {
+      HiveShim.appendReadColumns(conf, ids, names)
+    }
+  }
+
+  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+    // TODO: currently we do not check whether the "schema"s are compatible
+    // That means if one first creates a table and then INSERTs data with
+    // and incompatible schema the execution will fail. It would be nice
+    // to catch this early one, maybe having the planner validate the schema
+    // before calling execute().
+    import org.apache.hadoop.mapred.{FileOutputFormat, FileOutputCommitter}
+    import org.apache.spark.TaskContext
+
+    val (columnsNames, columnsTypes) = initialColumnsNamesTypes(data.schema)
+    @transient val job = new JobConf(sqlContext.sparkContext.hadoopConfiguration)
+    job.setOutputKeyClass(classOf[NullWritable])
+    job.setOutputValueClass(classOf[Row])
+    job.set("mapred.output.format.class", classOf[OrcOutputFormat].getName)
+    job.setOutputCommitter(classOf[FileOutputCommitter])
+    FileOutputFormat.setOutputPath(job, SparkHadoopWriter.createPathFromString(path, job))
+
+    val conf = new Configuration(job)
+    val destinationPath = new Path(path)
+    if (overwrite) {
+      try {
+        destinationPath.getFileSystem(conf).delete(destinationPath, true)
+      } catch {
+        case e: IOException =>
+          throw new IOException(
+            s"Unable to clear output directory ${destinationPath.toString} prior" +
+              s" to writing to Orc file:\n${e.toString}")
+      }
+    }
+
+    val taskIdOffset = if (overwrite) {
+      1
+    } else {
+      FileSystemHelper.findMaxTaskId(
+        FileOutputFormat.getOutputPath(job).toString, conf) + 1
+    }
+
+    val writer = new OrcHadoopWriter(job)
+    writer.preSetup()
+    sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _)
+    writer.commitJob()
+
+    // this function is executed on executor side
+    def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = {
+      val nullWritable = NullWritable.get()
+      val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
+
+      val serializer = {
+        val prop: Properties = new Properties
+        prop.setProperty("columns", columnsNames)
+        prop.setProperty("columns.types", columnsTypes)
+        val serde = new OrcSerde
+        serde.initialize(null, prop)
+        serde
+      }
+
+      val standardOI = ObjectInspectorUtils
+        .getStandardObjectInspector(serializer.getObjectInspector, ObjectInspectorCopyOption.JAVA)
+        .asInstanceOf[StructObjectInspector]
+      val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
+      val wrappers = fieldOIs.map(wrapperFor)
+      val outputData = new Array[Any](fieldOIs.length)
+
+      writer.setup(context.stageId, context.partitionId + taskIdOffset, taskAttemptId)
+      writer.open()
+      var row: Row = null
+      var i = 0
+      try {
+        while (iterator.hasNext) {
+          row = iterator.next()
+          i = 0
+          while (i < fieldOIs.length) {
+            outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i))
+            i += 1
+          }
+          writer.write(nullWritable, serializer.serialize(outputData, standardOI))
+        }
+      } finally {
+        writer.close()
+      }
+      writer.commit()
+    }
+  }
+}
+
+private[orc] object FileSystemHelper {
+  def listFiles(pathStr: String, conf: Configuration): Seq[Path] = {
+    val origPath = new Path(pathStr)
+    val fs = origPath.getFileSystem(conf)
+    if (fs == null) {
+      throw new IllegalArgumentException(
+        s"OrcTableOperations: Path $origPath is incorrectly formatted")
+    }
+    val path = origPath.makeQualified(fs)
+    if (!fs.exists(path) || !fs.getFileStatus(path).isDir) {
+      throw new IllegalArgumentException(
+        s"OrcTableOperations: path $path does not exist or is not a directory")
+    }
+    fs.globStatus(path)
+      .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) }
+      .map(_.getPath)
+  }
+
+  /**
+   * Finds the maximum taskid in the output file names at the given path.
+   */
+  def findMaxTaskId(pathStr: String, conf: Configuration): Int = {
+    val files = FileSystemHelper.listFiles(pathStr, conf)
+    // filename pattern is part-
+    val nameP = new scala.util.matching.Regex("""part-(\d{1,})""", "taskid")
+    val hiddenFileP = new scala.util.matching.Regex("_.*")
+    files.map(_.getName).map {
+      case nameP(taskid) => taskid.toInt
+      case hiddenFileP() => 0
+      case other: String =>
+        sys.error("ERROR: attempting to append to set of Orc files and found file" +
+          s"that does not match name pattern: $other")
+      case _ => 0
+    }.reduceLeft((a, b) => if (a < b) b else a)
+  }
+}
+
+class OrcHadoopWriter(@transient jobConf: JobConf) extends SparkHadoopWriter(jobConf) {
+  import java.text.NumberFormat
+  import org.apache.hadoop.mapred._
+
+  override def open() {
+    val numfmt = NumberFormat.getInstance()
+    numfmt.setMinimumIntegerDigits(5)
+    numfmt.setGroupingUsed(false)
+
+    val outputName = "part-" + numfmt.format(splitID)
+    val path = FileOutputFormat.getOutputPath(conf.value)
+    val fs: FileSystem = {
+      if (path != null) {
+        path.getFileSystem(conf.value)
+      } else {
+        FileSystem.get(conf.value)
+      }
+    }
+
+    // get the path of the temporary output file
+    val name = FileOutputFormat.getTaskOutputPath(conf.value, outputName).toString
+
+    getOutputCommitter().setupTask(getTaskContext())
+    writer = getOutputFormat().getRecordWriter(fs, conf.value, name, Reporter.NULL)
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
new file mode 100644
index 000000000000..5cc75e67ad98
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{Row, QueryTest}
+import org.apache.spark.sql.hive.test.TestHive._
+
+case class OrcData(intField: Int, stringField: String)
+
+abstract class OrcTest extends QueryTest with BeforeAndAfterAll {
+  var orcTableDir: File = null
+  var orcTableAsDir: File = null
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    orcTableAsDir = File.createTempFile("orctests", "sparksql")
+    orcTableAsDir.delete()
+    orcTableAsDir.mkdir()
+
+    // Hack: to prepare orc data files using hive external tables
+    orcTableDir = File.createTempFile("orctests", "sparksql")
+    orcTableDir.delete()
+    orcTableDir.mkdir()
+    import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+    (sparkContext
+      .makeRDD(1 to 10)
+      .map(i => OrcData(i, s"part-$i")))
+      .toDF()
+      .registerTempTable(s"orc_temp_table")
+
+    sql(s"""
+      create external table normal_orc
+      (
+        intField INT,
+        stringField STRING
+      )
+      STORED AS orc
+      location '${orcTableDir.getCanonicalPath}'
+    """)
+
+    sql(
+      s"""insert into table normal_orc
+         select intField, stringField from orc_temp_table""")
+  }
+
+  override def afterAll(): Unit = {
+    orcTableDir.delete()
+    orcTableAsDir.delete()
+  }
+
+  test("create temporary orc table") {
+    checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10))
+
+    checkAnswer(
+      sql("SELECT * FROM normal_orc_source"),
+      Row(1, "part-1") ::
+      Row(2, "part-2") ::
+      Row(3, "part-3") ::
+      Row(4, "part-4") ::
+      Row(5, "part-5") ::
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM normal_orc_source where intField > 5"),
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"),
+      Row(1, "part-1") ::
+      Row(1, "part-2") ::
+      Row(1, "part-3") ::
+      Row(1, "part-4") ::
+      Row(1, "part-5") ::
+      Row(1, "part-6") ::
+      Row(1, "part-7") ::
+      Row(1, "part-8") ::
+      Row(1, "part-9") ::
+      Row(1, "part-10") :: Nil
+    )
+  }
+
+  test("create temporary orc table as") {
+    checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10))
+
+    checkAnswer(
+      sql("SELECT * FROM normal_orc_source"),
+      Row(1, "part-1") ::
+      Row(2, "part-2") ::
+      Row(3, "part-3") ::
+      Row(4, "part-4") ::
+      Row(5, "part-5") ::
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM normal_orc_source where intField > 5"),
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"),
+      Row(1, "part-1") ::
+      Row(1, "part-2") ::
+      Row(1, "part-3") ::
+      Row(1, "part-4") ::
+      Row(1, "part-5") ::
+      Row(1, "part-6") ::
+      Row(1, "part-7") ::
+      Row(1, "part-8") ::
+      Row(1, "part-9") ::
+      Row(1, "part-10") :: Nil
+    )
+  }
+
+  test("appending insert") {
+    sql("insert into table normal_orc_source select * from orc_temp_table where intField > 5")
+    checkAnswer(
+      sql("select * from normal_orc_source"),
+      Row(1, "part-1") ::
+      Row(2, "part-2") ::
+      Row(3, "part-3") ::
+      Row(4, "part-4") ::
+      Row(5, "part-5") ::
+      Row(6, "part-6") ::
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") ::
+      Row(10, "part-10") :: Nil
+    )
+  }
+
+  test("overwrite insert") {
+    sql(
+      """
+        |insert overwrite table normal_orc_as_source
+        |select * from orc_temp_table where intField > 5
+      """.stripMargin)
+    checkAnswer(
+      sql("select * from normal_orc_as_source"),
+      Row(6, "part-6") ::
+      Row(7, "part-7") ::
+      Row(8, "part-8") ::
+      Row(9, "part-9") ::
+      Row(10, "part-10") :: Nil
+    )
+  }
+}
+
+class OrcSourceSuite extends OrcTest {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    sql(s"""
+      create temporary table normal_orc_source
+      USING org.apache.spark.sql.hive.orc
+      OPTIONS (
+        path '${new File(orcTableDir.getAbsolutePath).getCanonicalPath}'
+      )
+    """)
+
+    sql(s"""
+      create temporary table normal_orc_as_source
+      USING org.apache.spark.sql.hive.orc
+      OPTIONS (
+        path '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}'
+      )
+      as select * from orc_temp_table
+    """)
+  }
+}
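
Usage sketch (not part of the patch): based on the DefaultSource scaladoc and OrcSourceSuite above, reading and writing ORC data through the new source would look roughly like the following. The table name, paths, and the `df.save(path, source, mode)` call are illustrative assumptions, not something this patch defines.

  // Hypothetical usage, mirroring OrcSourceSuite; "path/to/orc" is a placeholder.
  sqlContext.sql(s"""
    CREATE TEMPORARY TABLE orc_table
    USING org.apache.spark.sql.hive.orc
    OPTIONS (path 'path/to/orc')
  """)
  sqlContext.sql("SELECT COUNT(*) FROM orc_table").collect()

  // Writing goes through CreatableRelationProvider.createRelation, assuming the
  // DataFrame.save(path, source, mode) overload of this Spark version.
  df.save("path/to/orc_out", "org.apache.spark.sql.hive.orc", SaveMode.Overwrite)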